diff --git a/.gitignore b/.gitignore index 08f20946..7433105b 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,4 @@ *.app /build -/third_party/* \ No newline at end of file +/third_party/* diff --git a/.gitmodules b/.gitmodules index c7d36aaa..aeca9a82 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,6 @@ [submodule "third_party/boringssl"] path = third_party/boringssl url = https://github.com/google/boringssl.git +[submodule "third_party/rapidjson"] + path = third_party/rapidjson + url = https://github.com/Tencent/rapidjson.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 06bf900a..b0303970 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ INCLUDE_DIRECTORIES( third_party/spdlog/include third_party/abseil-cpp third_party/boringssl/include + third_party/rapidjson/include ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR} ${CMAKE_BINARY_DIR}/quiche @@ -43,8 +44,21 @@ SET(BASE_SRCS base/files/file_util.cc base/files/file_util_posix.cc base/strings/stringprintf.cc + base/bvc-qlog/src/qlogger_types.h + base/bvc-qlog/src/qlogger_types.cc + base/bvc-qlog/src/qlogger_constants.h + base/bvc-qlog/src/qlogger_constants.cc + base/bvc-qlog/src/qlogger.h + base/bvc-qlog/src/qlogger.cc + base/bvc-qlog/src/base_qlogger.h + base/bvc-qlog/src/base_qlogger.cc + base/bvc-qlog/src/file_qlogger.h + base/bvc-qlog/src/file_qlogger.cc + base/bvc-qlog/src/bvc_quic_connection_debug_visitor.h + base/bvc-qlog/src/bvc_quic_connection_debug_visitor.cc ) + SET(NET_SRCS net/base/io_buffer.cc net/http/http_util.cc @@ -65,7 +79,6 @@ SET(PLATFORM_SRCS platform/quiche_platform_impl/quic_mutex_impl.cc platform/quic_platform_impl/quic_cert_utils_impl.cc platform/quic_platform_impl/quic_default_proof_providers_impl.cc - platform/quic_platform_impl/quic_file_utils_impl.cc platform/quic_platform_impl/quic_hostname_utils_impl.cc platform/quic_platform_impl/quic_mem_slice_impl.cc platform/quic_platform_impl/quic_mem_slice_span_impl.cc @@ -74,6 +87,9 @@ SET(PLATFORM_SRCS ) SET(GQUICHE_SRCS + gquiche/common/platform/api/quiche_file_utils.cc + gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc + gquiche/common/quiche_text_utils.cc gquiche/common/quiche_data_reader.cc gquiche/common/quiche_data_writer.cc gquiche/http2/decoder/decode_buffer.cc @@ -184,6 +200,8 @@ SET(GQUICHE_SRCS gquiche/quic/core/crypto/tls_connection.cc gquiche/quic/core/crypto/tls_server_connection.cc gquiche/quic/core/crypto/transport_parameters.cc + gquiche/quic/core/crypto/quic_client_session_cache.cc + gquiche/quic/core/crypto/client_proof_source.cc gquiche/quic/core/frames/quic_ack_frame.cc gquiche/quic/core/frames/quic_ack_frequency_frame.cc gquiche/quic/core/frames/quic_blocked_frame.cc @@ -207,6 +225,7 @@ SET(GQUICHE_SRCS gquiche/quic/core/frames/quic_stream_frame.cc gquiche/quic/core/frames/quic_streams_blocked_frame.cc gquiche/quic/core/frames/quic_window_update_frame.cc + gquiche/quic/core/http/capsule.cc gquiche/quic/core/http/http_constants.cc gquiche/quic/core/http/http_decoder.cc gquiche/quic/core/http/http_encoder.cc @@ -322,12 +341,15 @@ SET(GQUICHE_SRCS gquiche/quic/core/tls_server_handshaker.cc gquiche/quic/core/uber_quic_stream_id_manager.cc gquiche/quic/core/uber_received_packet_manager.cc - gquiche/quic/core/web_transport_stream_adapter.cc + gquiche/quic/core/quic_chaos_protector.cc + gquiche/quic/core/quic_connection_context.cc + gquiche/quic/core/http/web_transport_stream_adapter.cc gquiche/quic/platform/api/quic_file_utils.cc gquiche/quic/platform/api/quic_hostname_utils.cc gquiche/quic/platform/api/quic_ip_address.cc gquiche/quic/platform/api/quic_mutex.cc gquiche/quic/platform/api/quic_socket_address.cc + gquiche/quic/platform/api/quic_mem_slice_storage.cc gquiche/spdy/core/hpack/hpack_constants.cc gquiche/spdy/core/hpack/hpack_decoder_adapter.cc gquiche/spdy/core/hpack/hpack_encoder.cc @@ -353,6 +375,10 @@ SET(GQUICHE_SRCS SET(QUIC_TOOLS_SRCS gquiche/quic/tools/simple_ticket_crypter.cc + gquiche/quic/tools/quic_spdy_client_base.cc + gquiche/quic/tools/quic_client_base.cc + gquiche/quic/tools/quic_client.cc + gquiche/quic/tools/quic_client_epoll_network_helper.cc ) message("PROTO_SRCS = ${CoreProtoSource}") diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index c28ef9f7..26ef4ac8 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -1,12 +1,12 @@ # Compilation **1. Prerequisite** - +> export GOPROXY=https://goproxy.io > apt-get install git cmake build-essential protobuf-compiler libprotobuf-dev golang-go libunwind-dev libicu-dev **2. Build** -> git clone https://github.com/bilibili/quiche.git && cd quiche -> git submodule update --init +> git clone git@git.bilibili.co:quic/quiche.git && cd quiche +> git submodule update --init --recursive > mkdir -p build > cd build && cmake .. > make -j @@ -21,18 +21,18 @@ > git clone https://quiche.googlesource.com/quiche google_quiche > git clone https://quiche.googlesource.com/googleurl -**2. Select a proper tag from chromium and find out quiche version and boringssl version it depended on, then checkout it.** +**2. Select a proper tag from chromium and find out quiche version and boringssl version and absl version it depended on, then checkout it.** > git checkout [commit-id] **3. Rewrite** -> cp -fr google_quiche/* quiche/ +> cp -fr google_quiche/* quiche/gquiche > cp -fr googleurl/* quiche/googleurl/ > cd quiche && bash utils/google_quiche_rewrite.sh **4. Check if any file or dir should be checkouted** > git checkout README.md -**5. Update VERSION, log chromium tag and quiche/boringssl version** +**5. Update VERSION, log chromium tag and quiche/boringssl version and quiche/abseil-cpp version** **6. Compile and fix errors** > Repeate Compilation steps and fix errors. diff --git a/README.md b/README.md index 119c6411..01102b94 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Google quiche is used in Chromium (http://www.chromium.org/quic) project. This r - Easy building with cmake - Only support Linux platform - Easy to keep pace with Google quiche upgrading +- Support Qlog: APIs to expose transport and server statistics for debuggability with qvis ### Source Layout - `base`: Implementation of basic platform functions diff --git a/VERSION b/VERSION index 0f1305ea..4a8408b8 100644 --- a/VERSION +++ b/VERSION @@ -1,5 +1,5 @@ ->>> VERSION LIST -chromium_tag: 92.0.4484.8 -quiche_version: 2a62c51 -boringssl_version: 1596137 +chromium_tag: 98.0.4758.81 +quiche_version: 4003b67 +boringssl_version: 3a667d1 googleurl_version: 4c7645e +absl_version: 9336be0 diff --git a/base/bvc-qlog/src/base_qlogger.cc b/base/bvc-qlog/src/base_qlogger.cc new file mode 100644 index 00000000..b50ed5df --- /dev/null +++ b/base/bvc-qlog/src/base_qlogger.cc @@ -0,0 +1,373 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +#include "base/bvc-qlog/src/base_qlogger.h" +#include "gquiche/quic/core/quic_types.h" +#include "platform/quiche_platform_impl/quiche_text_utils_impl.h" +#include "platform/spdy_platform_impl/spdy_string_utils_impl.h" + +namespace quic { + +std::unique_ptr BaseQLogger::createPacketEventImpl( + const QuicPacketHeader& packetHeader, + uint64_t packetSize, + bool isPacketRecvd) { + auto event = std::make_unique(); + event->refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + event->packetNum = packetHeader.packet_number.ToUint64(); + event->packetSize = packetSize; + event->eventType = QLogEventType::PacketReceived; + summary_.totalPacketsRecvd++; + summary_.totalBytesRecvd += packetSize; + + if (packetHeader.form == IETF_QUIC_SHORT_HEADER_PACKET) { + event->packetType = std::string(kShortHeaderPacketType); + } else if (packetHeader.form == GOOGLE_QUIC_PACKET) { + event->packetType = std::string(kGooglePacketType); + } else { + event->packetType = + std::string(toQlogString(packetHeader.long_packet_type)); + } + return event; +} + +void BaseQLogger::addPacketFrameImpl( + QLogPacketEvent* event, + QuicFrameType frame_type, + void* frame, + bool isPacketRecvd) { + switch (frame_type) { + // Stream Frame + case QuicFrameType::STREAM_FRAME: { + QuicStreamFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->stream_id, f->offset, f->data_length, f->fin)); + break; + } + // Ack Frame + case QuicFrameType::ACK_FRAME: { + QuicAckFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->packets, f->ack_delay_time.ToMicroseconds())); + break; + } + // Padding Frame + case QuicFrameType::PADDING_FRAME: { + event->frames.push_back(std::make_unique()); + break; + } + //Reset Stream Frame + case QuicFrameType::RST_STREAM_FRAME: { + QuicRstStreamFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->stream_id, f->error_code, f->byte_offset)); + break; + } + // Connection Close Frame + case QuicFrameType::CONNECTION_CLOSE_FRAME: { + QuicConnectionCloseFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->close_type, f->wire_error_code, f->quic_error_code, f->error_details, f->transport_close_frame_type)); + break; + } + // Goaway Frame + case QuicFrameType::GOAWAY_FRAME: { + QuicGoAwayFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->error_code, f->last_good_stream_id, f->reason_phrase)); + break; + } + // Window Update Frame + case QuicFrameType::WINDOW_UPDATE_FRAME: { + QuicWindowUpdateFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->stream_id, f->max_data)); + break; + } + // Blocked Frame + case QuicFrameType::BLOCKED_FRAME: { + QuicBlockedFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->stream_id)); + break; + } + // Stop Waiting Frame + case QuicFrameType::STOP_WAITING_FRAME: { + event->frames.push_back(std::make_unique()); + break; + } + // Ping Frame + case QuicFrameType::PING_FRAME: { + event->frames.push_back(std::make_unique()); + break; + } + // Crypto Frame + case QuicFrameType::CRYPTO_FRAME: { + QuicCryptoFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->level, f->offset, f->data_length)); + break; + } + // Handshake Done Frame + case QuicFrameType::HANDSHAKE_DONE_FRAME: { + event->frames.push_back(std::make_unique()); + break; + } + // MTU Discovery Frame + case QuicFrameType::MTU_DISCOVERY_FRAME: { + event->frames.push_back(std::make_unique()); + break; + } + // New Connection ID Frame + case QuicFrameType::NEW_CONNECTION_ID_FRAME: { + QuicNewConnectionIdFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->connection_id.ToString(), f->sequence_number)); + break; + } + // Max Streams Frame + case QuicFrameType::MAX_STREAMS_FRAME: { + QuicMaxStreamsFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->stream_count, f->unidirectional)); + break; + } + // Streams Blocked Frame + case QuicFrameType::STREAMS_BLOCKED_FRAME: { + QuicStreamsBlockedFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->stream_count, f->unidirectional)); + break; + } + // Path Response Frame + case QuicFrameType::PATH_RESPONSE_FRAME: { + QuicPathResponseFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + spdy::SpdyHexEncodeImpl(reinterpret_cast(f->data_buffer.data()), f->data_buffer.size()))); + break; + } + // Path Challenge Frame + case QuicFrameType::PATH_CHALLENGE_FRAME: { + QuicPathChallengeFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique(spdy::SpdyHexEncodeImpl( + reinterpret_cast(f->data_buffer.data()), + f->data_buffer.size()))); + break; + } + // Stop Sending Frame + case QuicFrameType::STOP_SENDING_FRAME: { + QuicStopSendingFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique(f->stream_id, f->error_code)); + break; + } + // Message Frame + case QuicFrameType::MESSAGE_FRAME: { + QuicMessageFrame* f = static_cast(frame); + event->frames.push_back( + std::make_unique(f->message_id, f->message_length)); + break; + } + // New Token Frame + case QuicFrameType::NEW_TOKEN_FRAME: { + event->frames.push_back(std::make_unique()); + break; + } + // Retire Connection ID Frame + case QuicFrameType::RETIRE_CONNECTION_ID_FRAME: { + QuicRetireConnectionIdFrame* f = static_cast(frame); + event->frames.push_back( + std::make_unique(f->sequence_number)); + break; + } + // Ack Frequency Frame + case QuicFrameType::ACK_FREQUENCY_FRAME: { + QuicAckFrequencyFrame* f = static_cast(frame); + event->frames.push_back(std::make_unique( + f->sequence_number, f->packet_tolerance, + f->max_ack_delay.ToMilliseconds(), f->ignore_order)); + break; + } + // Num Frame Types + case QuicFrameType::NUM_FRAME_TYPES: + default: { + break; + } + } +} + +std::unique_ptr BaseQLogger::createPacketEventImpl( + const QuicVersionNegotiationPacket& versionNegotiationPacket, + uint64_t packetSize, + bool isPacketRecvd) { + auto event = std::make_unique(); + event->refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + event->packetSize = packetSize; + event->eventType = + isPacketRecvd ? QLogEventType::PacketReceived : QLogEventType::PacketSent; + event->packetType = kVersionNegotiationPacketType; + event->versionLog = std::make_unique( + VersionNegotiationLog(versionNegotiationPacket.versions)); + summary_.totalPacketsRecvd++; + summary_.totalBytesRecvd += packetSize; + return event; +} + +std::unique_ptr BaseQLogger::createPacketEventImpl( + const std::string newConnectionId, + uint64_t packetSize, + bool isPacketRecvd) { + auto event = std::make_unique(); + event->refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + event->packetSize = packetSize; + event->eventType = + isPacketRecvd ? QLogEventType::PacketReceived : QLogEventType::PacketSent; + event->packetType = kQuicPublicResetPacketType; + summary_.totalPacketsRecvd++; + summary_.totalBytesRecvd += packetSize; + return event; +} + +std::unique_ptr BaseQLogger::createPacketEventImpl( + const QuicPublicResetPacket& publicResetPacket, + uint64_t packetSize, + bool isPacketRecvd) { + auto event = std::make_unique(); + event->refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + event->packetSize = packetSize; + event->eventType = + isPacketRecvd ? QLogEventType::PacketReceived : QLogEventType::PacketSent; + event->packetType = kQuicPublicResetPacketType; + summary_.totalPacketsRecvd++; + summary_.totalBytesRecvd += packetSize; + return event; +} + +std::unique_ptr BaseQLogger::createPacketEventImpl( + uint64_t packet_number, + uint64_t packet_length, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + bool isPacketRecvd) { + auto event = std::make_unique(); + event->refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + event->packetNum = packet_number; + event->packetSize = packet_length; + event->eventType = + isPacketRecvd ? QLogEventType::PacketReceived : QLogEventType::PacketSent; + if (encryption_level >= ENCRYPTION_FORWARD_SECURE) { + event->packetType = std::string(kShortHeaderPacketType); + } else { + event->packetType = + std::string(toQlogString(encryptionLevelToLongHeaderType(encryption_level))); + } + event->transmissionType = + std::string(toQlogString(transmission_type)); + + summary_.totalPacketsSent++; + summary_.totalBytesSent += packet_length; + + for (const QuicFrame& frame : retransmittable_frames) { + void* fv = getFrameType(frame); + if (fv != NULL) { + addPacketFrameImpl(event.get(), frame.type, fv, isPacketRecvd); + } + } + for (const QuicFrame& frame : nonretransmittable_frames) { + void* fv = getFrameType(frame); + if (fv != NULL) { + addPacketFrameImpl(event.get(), frame.type, fv, isPacketRecvd); + } + } + return event; +} + +void* BaseQLogger::getFrameType(const QuicFrame& frame) { + void* fv; + switch (frame.type) { + case QuicFrameType::STREAM_FRAME: + fv = (void*)&frame.stream_frame; + break; + case QuicFrameType::RST_STREAM_FRAME: + fv = (void*)frame.rst_stream_frame; + break; + case QuicFrameType::CONNECTION_CLOSE_FRAME: + fv = (void*)frame.connection_close_frame; + break; + case QuicFrameType::WINDOW_UPDATE_FRAME: + fv = (void*)frame.window_update_frame; + break; + case QuicFrameType::BLOCKED_FRAME: + fv = (void*)frame.blocked_frame; + break; + case QuicFrameType::PING_FRAME: + fv = (void*)&frame.ping_frame; + break; + case QuicFrameType::HANDSHAKE_DONE_FRAME: + fv = (void*)&frame.handshake_done_frame; + break; + case QuicFrameType::ACK_FREQUENCY_FRAME: + fv = (void*)frame.ack_frequency_frame; + break; + case QuicFrameType::PADDING_FRAME: + fv = (void*)&frame.padding_frame; + break; + case QuicFrameType::MTU_DISCOVERY_FRAME: + fv = (void*)&frame.mtu_discovery_frame; + break; + case QuicFrameType::STOP_WAITING_FRAME: + fv = (void*)&frame.stop_waiting_frame; + break; + case QuicFrameType::ACK_FRAME: + fv = (void*)frame.ack_frame; + break; + + // New IETF frames, not used in current gQUIC version. + case QuicFrameType::NEW_CONNECTION_ID_FRAME: + fv = (void*)frame.new_connection_id_frame; + break; + case QuicFrameType::RETIRE_CONNECTION_ID_FRAME: + fv = (void*)frame.retire_connection_id_frame; + break; + case QuicFrameType::MAX_STREAMS_FRAME: + fv = (void*)&frame.max_streams_frame; + break; + case QuicFrameType::STREAMS_BLOCKED_FRAME: + fv = (void*)&frame.streams_blocked_frame; + break; + case QuicFrameType::PATH_RESPONSE_FRAME: + fv = (void*)frame.path_response_frame; + break; + case QuicFrameType::PATH_CHALLENGE_FRAME: + fv = (void*)frame.path_challenge_frame; + break; + case QuicFrameType::STOP_SENDING_FRAME: + fv = (void*)frame.stop_sending_frame; + break; + case QuicFrameType::MESSAGE_FRAME: + fv = (void*)frame.message_frame; + break; + case QuicFrameType::CRYPTO_FRAME: + fv = (void*)frame.crypto_frame; + break; + case QuicFrameType::NEW_TOKEN_FRAME: + fv = (void*)frame.new_token_frame; + break; + // Ignore gQUIC-specific frames. + case QuicFrameType::GOAWAY_FRAME: + fv = (void*)frame.goaway_frame; + break; + case QuicFrameType::NUM_FRAME_TYPES: + default: + fv = NULL; + break; + } + return fv; +} + +} // namespace quic diff --git a/base/bvc-qlog/src/base_qlogger.h b/base/bvc-qlog/src/base_qlogger.h new file mode 100644 index 00000000..3f91608f --- /dev/null +++ b/base/bvc-qlog/src/base_qlogger.h @@ -0,0 +1,84 @@ +#pragma once + +#include "base/bvc-qlog/src/qlogger.h" +#include "base/bvc-qlog/src/qlogger_constants.h" +#include "base/bvc-qlog/src/qlogger_types.h" + +namespace quic { + +class BaseQLogger : public QLogger { + public: + explicit BaseQLogger(VantagePoint vantagePointIn, std::string protocolTypeIn) + : QLogger(vantagePointIn, std::move(protocolTypeIn)) {} + + ~BaseQLogger() override = default; + + protected: + std::unique_ptr createPacketEventImpl( + const std::string newConnectionId, + uint64_t packetSize, + bool isPacketRecvd); + + std::unique_ptr createPacketEventImpl( + const QuicPublicResetPacket& publicResetPacket, + uint64_t packetSize, + bool isPacketRecvd); + + std::unique_ptr createPacketEventImpl( + const QuicVersionNegotiationPacket& versionNegotiationPacket, + uint64_t packetSize, + bool isPacketRecvd); + + std::unique_ptr createPacketEventImpl( + const QuicPacketHeader& packetHeader, + uint64_t packetSize, + bool isPacketRecvd); + + void addFramesProcessedImpl( + QLogFramesProcessed* event, + QuicFrameType frame_type, + void* frame, + uint64_t packet_number, + uint64_t packet_size, + std::string packet_type, + std::chrono::microseconds time_dirft); + + void addPacketFrameImpl( + QLogPacketEvent* event, + QuicFrameType frame_type, + void* frame, + bool isPacketRecvd); + + std::unique_ptr createPacketEventImpl( + uint64_t packet_number, + uint64_t packet_length, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + bool isPacketRecvd, + bool aggregate); + + std::unique_ptr createPacketEventImpl( + uint64_t packet_number, + uint64_t packet_length, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + bool isPacketRecvd); + + void* getFrameType(const quic::QuicFrame& frame); + + std::chrono::microseconds steady_startTime_ = + std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()); + std::chrono::microseconds steady_packetTime_ = + std::chrono::duration_cast(std::chrono::steady_clock::now().time_since_epoch()); + std::chrono::microseconds system_startTime_ = + std::chrono::duration_cast(std::chrono::system_clock::now().time_since_epoch()); + std::chrono::microseconds endTime_; + + //* for frame aggregation holder, whether or not place it here is a problem. + +}; +} // namespace quic diff --git a/base/bvc-qlog/src/bvc_quic_connection_debug_visitor.cc b/base/bvc-qlog/src/bvc_quic_connection_debug_visitor.cc new file mode 100644 index 00000000..72eb4be7 --- /dev/null +++ b/base/bvc-qlog/src/bvc_quic_connection_debug_visitor.cc @@ -0,0 +1,296 @@ +#include "bvc_quic_connection_debug_visitor.h" +#include "base/bvc-qlog/src/qlogger_constants.h" +#include "base/bvc-qlog/src/file_qlogger.h" +#include "gquiche/quic/core/quic_utils.h" +#include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/quic/core/congestion_control/bbr_sender.h" + +using namespace quic; + +namespace bvc { + +BvcQuicConnectionDebugVisitor::BvcQuicConnectionDebugVisitor( + FileQLogger* qlogger, QuicConnection* connection) + : qlogger_(qlogger), + connection_(connection) { +} + +void BvcQuicConnectionDebugVisitor::OnPacketReceived( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + const QuicEncryptedPacket& packet) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + packet_length_ = packet.length(); +} + + +void BvcQuicConnectionDebugVisitor::OnPacketSent( + QuicPacketNumber packet_number, + QuicPacketLength packet_length, + bool has_crypto_handshake, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + QuicTime sent_time) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + (qlogger_)->addPacket(packet_number.ToUint64(), packet_length, transmission_type, + encryption_level, retransmittable_frames, nonretransmittable_frames, false); +} + +void BvcQuicConnectionDebugVisitor::OnPublicResetPacket(const QuicPublicResetPacket& packet) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + (qlogger_)->addPacket(packet, packet_length_, true); +} + +void BvcQuicConnectionDebugVisitor::OnVersionNegotiationPacket(const QuicVersionNegotiationPacket& packet) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + (qlogger_)->addPacket(packet, packet_length_, true); +} + +void BvcQuicConnectionDebugVisitor::OnPacketHeader( + const QuicPacketHeader& header, + QuicTime receive_time, + EncryptionLevel level) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + if (VersionHasIetfQuicFrames(header.version.transport_version)) { + // packet has IETF quic frames + } else { + // packet has google quic frames + } + current_event_ = (qlogger_)->createPacketEvent(header, packet_length_, true); +} + +void BvcQuicConnectionDebugVisitor::OnPacketComplete() { + if (qlogger_ == NULL || connection_ == NULL || current_event_.get() == NULL) { + return; + } + (qlogger_)->finishCreatePacketEvent(std::move(current_event_)); +} + +void BvcQuicConnectionDebugVisitor::OnConnectionClosed( + const QuicConnectionCloseFrame& frame, + ConnectionCloseSource source) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + + (qlogger_)->addConnectionClose(frame.quic_error_code, frame.error_details, source); +} + +void BvcQuicConnectionDebugVisitor::OnStreamFrame(const QuicStreamFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::STREAM_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnCryptoFrame(const QuicCryptoFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::CRYPTO_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnStopWaitingFrame(const QuicStopWaitingFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::STOP_WAITING_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnPaddingFrame(const QuicPaddingFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::PADDING_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnPingFrame(const QuicPingFrame& frame, QuicTime::Delta) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::PING_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnGoAwayFrame(const QuicGoAwayFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::GOAWAY_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnRstStreamFrame(const QuicRstStreamFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::RST_STREAM_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnConnectionCloseFrame( + const QuicConnectionCloseFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::CONNECTION_CLOSE_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnWindowUpdateFrame( + const QuicWindowUpdateFrame& frame, const QuicTime& receive_time) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::WINDOW_UPDATE_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnBlockedFrame(const QuicBlockedFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::BLOCKED_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnHandshakeDoneFrame(const QuicHandshakeDoneFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::HANDSHAKE_DONE_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnNewConnectionIdFrame( + const QuicNewConnectionIdFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::NEW_CONNECTION_ID_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::MAX_STREAMS_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnStreamsBlockedFrame(const QuicStreamsBlockedFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::STREAMS_BLOCKED_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnPathResponseFrame(const QuicPathResponseFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::PATH_RESPONSE_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnPathChallengeFrame(const QuicPathChallengeFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::PATH_CHALLENGE_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnStopSendingFrame(const QuicStopSendingFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::STOP_SENDING_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnMessageFrame(const QuicMessageFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::MESSAGE_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnNewTokenFrame(const QuicNewTokenFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::NEW_TOKEN_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnRetireConnectionIdFrame(const QuicRetireConnectionIdFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::RETIRE_CONNECTION_ID_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnAckFrequencyFrame(const QuicAckFrequencyFrame& frame) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::ACK_FREQUENCY_FRAME, const_cast(&frame), true); +} + +void BvcQuicConnectionDebugVisitor::OnAckFrameStart( + QuicTime::Delta ack_delay_time) { + if (qlogger_ == NULL || connection_ == NULL || current_event_ == NULL) { + return; + } + ack_frame_.reset(); + ack_frame_ = std::make_unique(); + ack_frame_->ack_delay_time = ack_delay_time; +} + +void BvcQuicConnectionDebugVisitor::OnAckRange( + QuicPacketNumber start, QuicPacketNumber end) { + if (qlogger_ == NULL || connection_ == NULL || + current_event_ == NULL || ack_frame_ == NULL) { + return; + } + //*: Because the interval uses half-closed range `[)` and causes confusion, + //*: minus 1 here to make it closed range '[]' + ack_frame_->packets.AddRange(start, end); +} + +void BvcQuicConnectionDebugVisitor::OnAckFrameEnd( + QuicPacketNumber start) { + if (qlogger_ == NULL || connection_ == NULL || + current_event_ == NULL || ack_frame_ == NULL) { + return; + } + (qlogger_)->addPacketFrame(current_event_.get(), QuicFrameType::ACK_FRAME, ack_frame_.get(), true); +} + +void BvcQuicConnectionDebugVisitor::OnIncomingAck( + QuicPacketNumber /*ack_packet_number*/, + EncryptionLevel /*ack_decrypted_level*/, + const QuicAckFrame& /*ack_frame*/, + QuicTime /*ack_receive_time*/, + QuicPacketNumber /*largest_observed*/, + bool /*rtt_updated*/, + QuicPacketNumber /*least_unacked_sent_packet*/) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } +} + +void BvcQuicConnectionDebugVisitor::OnPacketLoss( + QuicPacketNumber lost_packet_number, + EncryptionLevel encryption_level, + TransmissionType transmission_type, + QuicTime detection_time) { + if (qlogger_ == NULL || connection_ == NULL) { + return; + } + + (qlogger_)->addPacketLost(lost_packet_number.ToUint64(), encryption_level, transmission_type); +} + +} // namespace bvc diff --git a/base/bvc-qlog/src/bvc_quic_connection_debug_visitor.h b/base/bvc-qlog/src/bvc_quic_connection_debug_visitor.h new file mode 100644 index 00000000..ffbbd54a --- /dev/null +++ b/base/bvc-qlog/src/bvc_quic_connection_debug_visitor.h @@ -0,0 +1,233 @@ +#pragma once +#include "base/bvc-qlog/src/file_qlogger.h" +#include "gquiche/quic/core/quic_connection.h" +#include "gquiche/quic/core/quic_default_packet_writer.h" +#include "gquiche/quic/core/quic_packet_writer.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" +#include "gquiche/quic/core/quic_time.h" +#include "gquiche/quic/platform/api/quic_epoll.h" + +namespace bvc { + +class BvcQuicConnectionDebugVisitor : public quic::QuicConnectionDebugVisitor { + public: + BvcQuicConnectionDebugVisitor(quic::FileQLogger* qlogger, + quic::QuicConnection* connection); + ~BvcQuicConnectionDebugVisitor() override {} + + virtual void OnPacketSent(quic::QuicPacketNumber /*packet_number*/, + quic::QuicPacketLength /*packet_length*/, + bool /*has_crypto_handshake*/, + quic::TransmissionType /*transmission_type*/, + quic::EncryptionLevel /*encryption_level*/, + const quic::QuicFrames& /*retransmittable_frames*/, + const quic::QuicFrames& /*nonretransmittable_frames*/, + quic::QuicTime /*sent_time*/) override; + + // Called when a coalesced packet has been sent. + virtual void OnCoalescedPacketSent( + const quic::QuicCoalescedPacket& /*coalesced_packet*/, + size_t /*length*/) override {} + + // Called when a PING frame has been sent. + virtual void OnPingSent() override {} + + // Called when a packet has been received, but before it is + // validated or parsed. + virtual void OnPacketReceived(const quic::QuicSocketAddress& /*self_address*/, + const quic::QuicSocketAddress& /*peer_address*/, + const quic::QuicEncryptedPacket& /*packet*/) override; + + // Called when the unauthenticated portion of the header has been parsed. + virtual void OnUnauthenticatedHeader(const quic::QuicPacketHeader& /*header*/) override {} + + // Called when a packet is received with a connection id that does not + // match the ID of this connection. + virtual void OnIncorrectConnectionId(quic::QuicConnectionId /*connection_id*/) override {} + + // Called when an undecryptable packet has been received. If |dropped| is + // true, the packet has been dropped. Otherwise, the packet will be queued and + // connection will attempt to process it later. + virtual void OnUndecryptablePacket(quic::EncryptionLevel /*decryption_level*/, + bool /*dropped*/) override {} + + // Called when attempting to process a previously undecryptable packet. + virtual void OnAttemptingToProcessUndecryptablePacket( + quic::EncryptionLevel /*decryption_level*/) override {} + + // Called when a duplicate packet has been received. + virtual void OnDuplicatePacket(quic::QuicPacketNumber /*packet_number*/) override {} + + // Called when the protocol version on the received packet doensn't match + // current protocol version of the connection. + virtual void OnProtocolVersionMismatch(quic::ParsedQuicVersion /*version*/) override {} + + // Called when the complete header of a packet has been parsed. + virtual void OnPacketHeader(const quic::QuicPacketHeader& /*header*/, + quic::QuicTime /*receive_time*/, + quic::EncryptionLevel /*level*/) override; + + // Called when a StreamFrame has been parsed. + virtual void OnStreamFrame(const quic::QuicStreamFrame& /*frame*/) override; + + // Called when a CRYPTO frame containing handshake data is received. + virtual void OnCryptoFrame(const quic::QuicCryptoFrame& /*frame*/) override; + + // Called when a StopWaitingFrame has been parsed. + virtual void OnStopWaitingFrame(const quic::QuicStopWaitingFrame& /*frame*/) override; + + // Called when a QuicPaddingFrame has been parsed. + virtual void OnPaddingFrame(const quic::QuicPaddingFrame& /*frame*/) override; + + // Called when a Ping has been parsed. + virtual void OnPingFrame(const quic::QuicPingFrame& /*frame*/, + quic::QuicTime::Delta /*ping_received_delay*/) override; + + // Called when a GoAway has been parsed. + virtual void OnGoAwayFrame(const quic::QuicGoAwayFrame& /*frame*/) override; + + // Called when a RstStreamFrame has been parsed. + virtual void OnRstStreamFrame(const quic::QuicRstStreamFrame& /*frame*/) override; + + // Called when a ConnectionCloseFrame has been parsed. All forms + // of CONNECTION CLOSE are handled, Google QUIC, IETF QUIC + // CONNECTION CLOSE/Transport and IETF QUIC CONNECTION CLOSE/Application + virtual void OnConnectionCloseFrame( + const quic::QuicConnectionCloseFrame& /*frame*/) override; + + // Called when a WindowUpdate has been parsed. + virtual void OnWindowUpdateFrame(const quic::QuicWindowUpdateFrame& /*frame*/, + const quic::QuicTime& /*receive_time*/) override; + + // Called when a BlockedFrame has been parsed. + virtual void OnBlockedFrame(const quic::QuicBlockedFrame& /*frame*/) override; + + // Called when a NewConnectionIdFrame has been parsed. + virtual void OnNewConnectionIdFrame( + const quic::QuicNewConnectionIdFrame& /*frame*/) override; + + // Called when a RetireConnectionIdFrame has been parsed. + virtual void OnRetireConnectionIdFrame( + const quic::QuicRetireConnectionIdFrame& /*frame*/) override; + + // Called when a NewTokenFrame has been parsed. + virtual void OnNewTokenFrame(const quic::QuicNewTokenFrame& /*frame*/) override; + + // Called when a MessageFrame has been parsed. + virtual void OnMessageFrame(const quic::QuicMessageFrame& /*frame*/) override; + + // Called when a HandshakeDoneFrame has been parsed. + virtual void OnHandshakeDoneFrame(const quic::QuicHandshakeDoneFrame& /*frame*/) override; + + // Called when a public reset packet has been received. + virtual void OnPublicResetPacket(const quic::QuicPublicResetPacket& /*packet*/) override; + + // Called when a version negotiation packet has been received. + virtual void OnVersionNegotiationPacket( + const quic::QuicVersionNegotiationPacket& /*packet*/) override; + + // Called when the connection is closed. + virtual void OnConnectionClosed(const quic::QuicConnectionCloseFrame& /*frame*/, + quic::ConnectionCloseSource /*source*/) override; + + // Called when the version negotiation is successful. + virtual void OnSuccessfulVersionNegotiation( + const quic::ParsedQuicVersion& /*version*/) override {} + + // Called when a CachedNetworkParameters is sent to the client. + virtual void OnSendConnectionState( + const quic::CachedNetworkParameters& /*cached_network_params*/) override {} + + // Called when a CachedNetworkParameters are received from the client. + virtual void OnReceiveConnectionState( + const quic::CachedNetworkParameters& /*cached_network_params*/) override {} + + // Called when the connection parameters are set from the supplied + // |config|. + virtual void OnSetFromConfig(const quic::QuicConfig& /*config*/) override {} + + // Called when RTT may have changed, including when an RTT is read from + // the config. + virtual void OnRttChanged(quic::QuicTime::Delta /*rtt*/) const override {} + + // Called when a StopSendingFrame has been parsed. + virtual void OnStopSendingFrame(const quic::QuicStopSendingFrame& /*frame*/) override; + + // Called when a PathChallengeFrame has been parsed. + virtual void OnPathChallengeFrame(const quic::QuicPathChallengeFrame& /*frame*/) override; + + // Called when a PathResponseFrame has been parsed. + virtual void OnPathResponseFrame(const quic::QuicPathResponseFrame& /*frame*/) override; + + // Called when a StreamsBlockedFrame has been parsed. + virtual void OnStreamsBlockedFrame(const quic::QuicStreamsBlockedFrame& /*frame*/) override; + + // Called when a MaxStreamsFrame has been parsed. + virtual void OnMaxStreamsFrame(const quic::QuicMaxStreamsFrame& /*frame*/) override; + + // Called when |count| packet numbers have been skipped. + virtual void OnNPacketNumbersSkipped(quic::QuicPacketCount /*count*/, + quic::QuicTime /*now*/) override {} + + // Called for QUIC+TLS versions when we send transport parameters. + virtual void OnTransportParametersSent( + const quic::TransportParameters& /*transport_parameters*/) override {} + + // Called for QUIC+TLS versions when we receive transport parameters. + virtual void OnTransportParametersReceived( + const quic::TransportParameters& /*transport_parameters*/) override {} + + // Called for QUIC+TLS versions when we resume cached transport parameters for + // 0-RTT. + virtual void OnTransportParametersResumed( + const quic::TransportParameters& /*transport_parameters*/) override {} + + // Called for QUIC+TLS versions when 0-RTT is rejected. + virtual void OnZeroRttRejected(int /*reject_reason*/) override {} + + // Called for QUIC+TLS versions when 0-RTT packet gets acked. + virtual void OnZeroRttPacketAcked() override {} + + // Called on peer address change. + virtual void OnPeerAddressChange(quic::AddressChangeType /*type*/, + quic::QuicTime::Delta /*connection_time*/) override {} + + // Called when all frames in packet have been parsed. + virtual void OnPacketComplete() override; + + // Called when a ack frequency frame has been parsed.. + virtual void OnAckFrequencyFrame(const quic::QuicAckFrequencyFrame& /*frame*/) override; + + // Called when ack_delay_time in ack frame has been parsed. + virtual void OnAckFrameStart(quic::QuicTime::Delta ack_delay_time) override; + + // Called when ack_range in ack frame has been parsed. + virtual void OnAckRange(quic::QuicPacketNumber start, quic::QuicPacketNumber end) override; + + // Done processing the ack frame. + virtual void OnAckFrameEnd(quic::QuicPacketNumber start) override; + + // Output DebugState of congestion control for analysis. + virtual void OnIncomingAck(quic::QuicPacketNumber /*ack_packet_number*/, + quic::EncryptionLevel /*ack_decrypted_level*/, + const quic::QuicAckFrame& /*ack_frame*/, + quic::QuicTime /*ack_receive_time*/, + quic::QuicPacketNumber /*largest_observed*/, + bool /*rtt_updated*/, + quic::QuicPacketNumber /*least_unacked_sent_packet*/) override; + + virtual void OnPacketLoss(quic::QuicPacketNumber /*lost_packet_number*/, + quic::EncryptionLevel /*encryption_level*/, + quic::TransmissionType /*transmission_type*/, + quic::QuicTime /*detection_time*/) override; + + private: + quic::FileQLogger* qlogger_; + quic::QuicConnection* connection_; + std::unique_ptr current_event_; + std::unique_ptr ack_frame_; + uint64_t packet_length_; + }; +} // namespace bvc + diff --git a/base/bvc-qlog/src/file_qlogger.cc b/base/bvc-qlog/src/file_qlogger.cc new file mode 100644 index 00000000..5e67eede --- /dev/null +++ b/base/bvc-qlog/src/file_qlogger.cc @@ -0,0 +1,637 @@ +#include +#include +#include +#include +#include + +#include "base/bvc-qlog/src/file_qlogger.h" +#include "gquiche/quic/platform/api/quic_logging.h" + +#include "absl/strings/str_cat.h" + +using namespace spdlog; + +namespace quic { + + +void FileQLogger::setDcid(quiche::QuicheOptionalImpl connID) { + if (!connID->IsEmpty()) { + dcid_ = connID; + if (streaming_) { + setFileObject(); + setupStream(); + } + } +} + +void FileQLogger::setScid(quiche::QuicheOptionalImpl connID) { + if (!connID->IsEmpty()) { + scid_ = connID; + } +} + +void FileQLogger::setQuicVersion(const QuicTransportVersion version) { + summary_.quicVersion = version; +} + +void FileQLogger::usedZeroRtt(bool use) { + summary_.usedZeroRtt = use; +} + +void FileQLogger::initialSummary() { + Document initial_summary; + connection_duration_ = (numEvents_ == 0) ? 0 : (double)endTime_.count()/1000; + initial_summary.SetObject(); + auto duration = (numEvents_ == 0) ? (std::chrono::microseconds)0 : (std::chrono::microseconds)endTime_.count()/1000; + initial_summary = generateSummary(numEvents_, duration); + + Document copy_summary; + copy_summary.CopyFrom(initial_summary, copy_summary.GetAllocator()); +} + +void FileQLogger::createBaseJson() { + if (!metadata_head_.empty() && !metadata_head_extra_.empty()) { + return; + } + // Create the base json + Document qLog, traces; + qLog.SetObject(); + traces.SetObject(); + toJsonBase(qLog, traces); + Value& traces_value = qLog["traces"]; + traces_value.PushBack(traces, traces.GetAllocator()); + + StringBuffer buffer; + Writer writer(buffer); + qLog.Accept(writer); + std::string base_Json = buffer.GetString(); + + baseJson_.clear(); + if (prettyJson_) { + baseJson_ << std::setw(4) << base_Json; + } else { + baseJson_ << base_Json; + } + // start copying from base to outputFile, stop at events + metadata_head_.clear(); + metadata_head_extra_.clear(); + baseJson_.seekg(0, baseJson_.beg); + token_ = prettyJson_ ? "\"events\": [" : "\"events\":["; + while (getline(baseJson_, eventLine_)) { + pos_ = eventLine_.find(token_); + if (pos_ == std::string::npos) { + absl::StrAppend(&metadata_head_, eventLine_); + } else { + // Found the token + for (char c : eventLine_) { + // get the amount of spaces each event should be padded + eventsPadding_.clear(); + if (c == ' ') { + eventsPadding_ += ' '; + } else { + break; + } + } + // get metadata + absl::StrAppend(&metadata_head_, eventLine_.substr(0, pos_ + token_.size())); + absl::StrAppend(&metadata_head_extra_, eventLine_.substr(pos_ + token_.size(), eventLine_.size() - pos_ - token_.size() - (prettyJson_ ? 0 : 1)), ","); + break; + } + } +} + +void FileQLogger::setFileObject() { + absl::StrAppend(&path_, "/", dcid_->ToString(), ".qlog"); + if(!spdlog::details::os::path_exists(path_)) { + spdlog::details::os::create_dir(spdlog::details::os::dir_name(path_)); + } + + if(fileObj_.is_open()) { + fileObj_.close(); + } + fileObj_.open(path_, std::fstream::out); +} + +void FileQLogger::setSpdlogObject() { + auto file_sink = std::make_shared(path_, true); + auto formatter = std::make_unique("%v", pattern_time_type::local, std::string("")); + logger_ = std::make_shared(dcid_->ToString(), std::move(file_sink), tp_, async_overflow_policy::block); + logger_->set_formatter(std::move(formatter)); +} + +void FileQLogger::setupStream() { + // create the output file + if (dcid_->IsEmpty()) { + QUIC_LOG(ERROR) << "Error: No dcid found"; + return; + } + endLine_ = prettyJson_ ? "\n" : ""; + initialSummary(); + + if (fileObj_) { + createBaseJson(); + setSpdlogObject(); + logger_->info(metadata_head_); + } +} + +void FileQLogger::finishStream() { + connection_duration_ = (numEvents_ == 0) ? 0 : (double)endTime_.count()/1000; + Document summaryJson = generateSummary(numEvents_, endTime_); + + StringBuffer buffer; + Writer writer(buffer); + summaryJson.Accept(writer); + std::string summary = buffer.GetString(); + + if (fileObj_) { + // logger remaining event + if(log_event_buffer_ != 0 && !logstring_.empty()) { + logger_->info(logstring_); + } + // finish copying the line that was stopped on + std::string tail; + if (!prettyJson_) { + absl::StrAppend(&tail, metadata_head_extra_); + } else { + // copy all the remaining lines but the last one + std::string previousLine = eventsPadding_ + metadata_head_extra_; + while (getline(baseJson_, eventLine_)) { + absl::StrAppend(&tail, endLine_, previousLine); + previousLine = eventLine_; + } + } + std::stringstream summaryBuffer; + std::string line; + if (prettyJson_) { + absl::StrAppend(&tail, basePadding_, "\"summary\" : "); + summaryBuffer << summary; + } else { + absl::StrAppend(&tail, "\"summary\":"); + summaryBuffer << summary; + } + + std::string summaryPadding = ""; + // add padding to every line in the summary except the first + while (getline(summaryBuffer, line)) { + absl::StrAppend(&tail, summaryPadding, line, endLine_); + summaryPadding = basePadding_; + } + absl::StrAppend(&tail, "}"); + logger_->info(tail); + logger_->flush(); + fileObj_.close(); + } +} + +void FileQLogger::handleEvent(std::unique_ptr event) { + if (streaming_) { + ++numEvents_; + Document eventjson = event->toJson(); + + buffer_.Clear(); + writer_.Reset(buffer_); + eventjson.Accept(writer_); + std::string event_json = buffer_.GetString(); + if (fileObj_) { + std::stringstream eventBuffer; + std::string line; + if (prettyJson_) { + eventBuffer << std::setw(4) << event_json; + } else { + eventBuffer << event_json; + } + + if (numEvents_ > 1) { + absl::StrAppend(&logstring_, ","); + } + // add padding to every line in the event + while (getline(eventBuffer, line)) { + absl::StrAppend(&logstring_, endLine_, basePadding_, eventsPadding_, line); + } + + if(log_event_buffer_ == 0 || numEvents_ % log_event_buffer_ == 0) { + logger_->info(logstring_); + logstring_.clear(); + } + } + } else { + logs.push_back(std::move(event)); + } +} + +void FileQLogger::addPacket( + const QuicPublicResetPacket& publicResetPacket, + uint64_t packetSize, + bool isPacketRecvd) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(createPacketEventImpl(publicResetPacket, packetSize, isPacketRecvd)); +} + +void FileQLogger::addPacket( + const QuicVersionNegotiationPacket& versionNegotiationPacket, + uint64_t packetSize, + bool isPacketRecvd) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(createPacketEventImpl(versionNegotiationPacket, packetSize, isPacketRecvd)); +} + +void FileQLogger::addPacket( + uint64_t packet_number, + uint64_t packet_length, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + bool isPacketRecvd) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(createPacketEventImpl(packet_number, packet_length, transmission_type, encryption_level, retransmittable_frames, nonretransmittable_frames, isPacketRecvd)); +} + +std::unique_ptr FileQLogger::createPacketEvent( + const QuicPacketHeader& packetHeader, + uint64_t packetSize, + bool isPacketRecvd) { + return createPacketEventImpl(packetHeader, packetSize, isPacketRecvd); +} + +void FileQLogger::addPacketFrame( + QLogPacketEvent* event, + QuicFrameType frame_type, + void* frame, + bool isPacketRecvd) { + // aggregation switch + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + addPacketFrameImpl(event, frame_type, frame, isPacketRecvd); + return; +} + +void FileQLogger::finishCreatePacketEvent(std::unique_ptr event) { + if (event->frames.size() == 0) { + return; + } + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::move(event)); +} + +void FileQLogger::addConnectionClose( + QuicErrorCode error, + const std::string& reason, + ConnectionCloseSource source) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + error_ = QuicErrorCodeToString(error); + reason_ = reason; + source_ = ConnectionCloseSourceToString(source); + handleEvent(std::make_unique( + error, + reason, + source, + refTime)); +} + +void FileQLogger::addBandwidthEstUpdate( + uint64_t bytes, + std::chrono::microseconds interval) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + bytes, + interval, + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_)); +} + +void FileQLogger::addAppLimitedUpdate() { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + true, + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_)); +} + +void FileQLogger::addAppUnlimitedUpdate() { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + false, + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_)); +} + +void FileQLogger::addPacingMetricUpdate( + uint64_t pacingBurstSizeIn, + std::chrono::microseconds pacingIntervalIn) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + pacingBurstSizeIn, pacingIntervalIn, refTime)); +} + +void FileQLogger::addPacingObservation( + std::string& actual, + std::string& expect, + std::string& conclusion) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + actual, expect, conclusion, refTime)); +} + +void FileQLogger::addAppIdleUpdate(std::string& idleEvent, bool idle) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + idleEvent, idle, refTime)); +} + +void FileQLogger::addPacketDrop(size_t packetSize, std::string& dropReason) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + packetSize, dropReason, refTime)); +} + +void FileQLogger::addDatagramReceived(uint64_t dataLen) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent( + std::make_unique(dataLen, refTime)); +} + +void FileQLogger::addLossAlarm( + uint64_t largestSent, + uint64_t alarmCount, + uint64_t outstandingPackets, + std::string& type) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + largestSent, alarmCount, outstandingPackets, type, refTime)); +} + +void FileQLogger::addPacketLost( + uint64_t LostPacketNum, + EncryptionLevel level, + TransmissionType type) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + summary_.totalPacketsLost++; + handleEvent(std::make_unique( + LostPacketNum, level, type, refTime)); +} + +void FileQLogger::addTransportStateUpdate(std::string& update) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + update, refTime)); +} + +void FileQLogger::addPacketBuffered( + uint64_t packetNum, + EncryptionLevel protectionType, + uint64_t packetSize) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + packetNum, protectionType, packetSize, refTime)); +} + +void FileQLogger::addMetricUpdate( + std::chrono::microseconds latestRtt, + std::chrono::microseconds mrtt, + std::chrono::microseconds srtt, + std::chrono::microseconds ackDelay) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + latestRtt, mrtt, srtt, ackDelay, refTime)); +} + +Document FileQLogger::toJson() { + Document j, trace; + j.SetObject(); + trace.SetObject(); + + toJsonBase(j, trace); + Value& traces_value = j["traces"]; + traces_value.PushBack(trace, trace.GetAllocator()); + + + if (logs.size() > 0) { + Document summaryJson = generateSummary(logs.size(), logs.back()->refTime); + j.AddMember("summary", summaryJson, j.GetAllocator()); + } + + // convert stored logs into json event array + Value value; + value.SetArray(); + for (auto& event : logs) { + Value event_json; + event_json.CopyFrom(event->toJson(), j.GetAllocator()); + value.PushBack(event_json, j.GetAllocator()); + } + + j["traces"]["events"] = value; + return j; +} + +void FileQLogger::toJsonBase(Document& j, Document& traces) { + + Document::AllocatorType& j_allocator = j.GetAllocator(); + Document::AllocatorType& traces_allocator = traces.GetAllocator(); + + j.AddMember(Value(kQLogDescriptionField, j_allocator).Move(), + Value(kQLogDescription, j_allocator).Move(), + j_allocator); + + j.AddMember(Value(kQLogVersionField, j_allocator).Move(), + Value(kQLogVersion, j_allocator).Move(), + j_allocator); + + j.AddMember(Value(kQLogTitleField, j_allocator).Move(), + Value(kQLogTitle, j_allocator).Move(), + j_allocator); + Value event_fields; + j.AddMember("qlog_format", "JSON", j_allocator); + + Value value; + value.SetArray(); + j.AddMember("traces", value, j_allocator); + + // trace[common_fields] + std::string dcidStr = (!dcid_ || dcid_->IsEmpty()) ? "" : dcid_->ToString(); + std::string scidStr = (!scid_ || scid_->IsEmpty()) ? "" : scid_->ToString(); + value.SetObject(); + value.AddMember( "dcid", Value(dcidStr.c_str(), traces_allocator).Move(), traces_allocator); + value.AddMember( "protocol_type", Value(protocolType_.c_str(), traces_allocator).Move(), traces_allocator); + value.AddMember( "reference_time", system_startTime_.count(), traces_allocator); + value.AddMember( "scid", Value(scidStr.c_str(), traces_allocator).Move(), traces_allocator); + traces.AddMember("common_fields", value, traces_allocator); + + event_fields.SetArray(); + event_fields.PushBack("relative_time", traces_allocator); + event_fields.PushBack("category", traces_allocator); + event_fields.PushBack("event", traces_allocator); + event_fields.PushBack("data", traces_allocator); + traces.AddMember("event_fields", event_fields, traces_allocator); + + value.SetObject(); + value.AddMember( "time_offset", 0, traces_allocator); + value.AddMember( "time_units", Value(kQLogTimeUnits, traces_allocator).Move(), traces_allocator); + traces.AddMember("configuration", value, traces_allocator); + + traces.AddMember("description", Value(kQLogTraceDescription, traces_allocator).Move(), traces_allocator); + + value.SetArray(); + traces.AddMember("events", value, traces_allocator); + traces.AddMember("title", Value(kQLogTraceTitle, traces_allocator).Move(), traces_allocator); + + char host[100] = {0}; + gethostname(host, sizeof(host)); + + value.SetObject(); + value.AddMember( "type", Value(vantagePointString(vantagePoint_).data(), traces_allocator).Move(), traces_allocator); + value.AddMember( "name", Value(host, traces_allocator).Move(), traces_allocator); + traces.AddMember("vantage_point", value, traces_allocator); +} + +void FileQLogger::addSummary(Value& value, Document::AllocatorType& summary_allocator) { + value.AddMember("total_bytes_sent", summary_.totalBytesSent, summary_allocator); + value.AddMember("total_packets_sent", summary_.totalPacketsSent, summary_allocator); + value.AddMember("total_bytes_recvd", summary_.totalBytesRecvd, summary_allocator); + value.AddMember("total_packets_recvd", summary_.totalPacketsRecvd, summary_allocator); + value.AddMember("total_packets_lost", summary_.totalPacketsLost, summary_allocator); + value.AddMember("quic_transport_version", Value(QuicVersionToString(summary_.quicVersion).c_str(), summary_allocator).Move(), summary_allocator); + value.AddMember("connection_duration", ((int)(connection_duration_ * 100 + 0.5)) / 100.0, summary_allocator); +} + +Document FileQLogger::generateSummary( + size_t numEvents, + std::chrono::microseconds endTime) { + Document summaryObj; + summaryObj.SetObject(); + Document::AllocatorType& summary_allocator = summaryObj.GetAllocator(); + + Value key, value; + key.SetString(StringRef(kQLogTraceCountField)); + value.SetInt(1); + summaryObj.AddMember(key, value, summary_allocator); // hardcoded, we only support 1 trace right now + + // is calculated like this : if there is <= 1 event, summary [max_duration] is 0 + // otherwise, it is the (time of the last event - time of the first event) + key.SetString(StringRef("max_duration")); + value.SetInt64((numEvents == 0) ? 0 : endTime.count()); + summaryObj.AddMember(key, value, summary_allocator); + + key.SetString(StringRef("total_event_count")); + value.SetInt64(numEvents); + summaryObj.AddMember(key, value, summary_allocator); + + // summaryObj [report_summary] + value.SetObject(); + addSummary(value, summary_allocator); + summaryObj.AddMember("report_summary", value, summary_allocator); + + return summaryObj; +} + +void FileQLogger::addStreamStateUpdate( + quic::QuicStreamId id, + std::string& update, + quiche::QuicheOptionalImpl timeSinceStreamCreation) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + id, + update, + std::move(timeSinceStreamCreation), + vantagePoint_, + refTime)); +} + +void FileQLogger::addConnectionMigrationUpdate(bool intentionalMigration) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + intentionalMigration, vantagePoint_, refTime)); +} + +void FileQLogger::addPathValidationEvent(bool success) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + success, vantagePoint_, refTime)); +} + +void FileQLogger::addPriorityUpdate( + quic::QuicStreamId streamId, + uint8_t urgency, + bool incremental) { + auto refTime = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) - steady_startTime_; + endTime_ = refTime; + handleEvent(std::make_unique( + streamId, urgency, incremental, refTime)); +} + +void FileQLogger::outputLogsToFile(const std::string& path, bool prettyJson) { + if (streaming_) { + return; + } + if (dcid_->IsEmpty()) { + QUIC_LOG(ERROR) << "Error: No dcid_ found"; + return; + } + + std::string outputPath; + absl::StrAppend(&outputPath, path, "/", dcid_->ToString(), ".qlog"); + std::ofstream fileObj; + fileObj.open(outputPath, std::fstream::out); + + if (fileObj) { + Document qLog = toJson(); + StringBuffer buffer; + Writer writer(buffer); + qLog.Accept(writer); + std::string base_Json = buffer.GetString(); + if (prettyJson) { + fileObj << std::setw(4) << base_Json; + } else { + fileObj << base_Json; + } + } else { + QUIC_LOG(ERROR) << "Error: Can't write to provided path: " << path; + } + fileObj.close(); +} + +} // namespace quic diff --git a/base/bvc-qlog/src/file_qlogger.h b/base/bvc-qlog/src/file_qlogger.h new file mode 100644 index 00000000..c1b1d275 --- /dev/null +++ b/base/bvc-qlog/src/file_qlogger.h @@ -0,0 +1,195 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#pragma once +#include +#include +#include +#include +#include + +#include "base/bvc-qlog/src/base_qlogger.h" +#include "base/bvc-qlog/src/qlogger_constants.h" +#include "base/bvc-qlog/src/qlogger_types.h" +#include "base/sinks/sequence_file_sink.h" + +#include "gquiche/quic/core/quic_stream.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/core/quic_connection_id.h" +#include "spdlog/async.h" +#include "spdlog/sinks/basic_file_sink.h" +#include "rapidjson/document.h" +#include "rapidjson/writer.h" +#include "rapidjson/stringbuffer.h" + +using namespace rapidjson; + +namespace quic { + +class FileQLogger : public BaseQLogger { + public: + using QLogger::TransportSummaryArgs; + std::vector> logs; + FileQLogger( + VantagePoint vantagePointIn, + std::string& path, + std::shared_ptr tp, + std::size_t log_event_buffer = 0, + std::string protocolTypeIn = kHTTP3ProtocolType, + bool prettyJson = false, + bool streaming = true) + : BaseQLogger(vantagePointIn, std::move(protocolTypeIn)), + path_(path), + tp_(tp), + log_event_buffer_(log_event_buffer), + prettyJson_(prettyJson), + streaming_(streaming) {} + + ~FileQLogger() override { + if (streaming_ && !dcid_->IsEmpty()) { + finishStream(); + } + } + + // retry packet (client) + void addPacket(const std::string& newConnectionId, uint64_t packetSize, bool isPacketRecvd) {} + // public reset packet + void addPacket(const QuicPublicResetPacket& publicResetPacket, uint64_t packetSize, bool isPacketRecvd); + // version negotiation packet + void addPacket( + const QuicVersionNegotiationPacket& versionNegotiationPacket, + uint64_t packetSize, + bool isPacketRecvd); + // ietf stateless reset packet (client) + void addPacket(const QuicIetfStatelessResetPacket& ietfStatelessResetPacket, uint64_t packetSize, bool isPacketRecvd) {} + void addPacketFrame( + QLogPacketEvent* event, + QuicFrameType frame_type, + void* frame, + bool isPacketRecvd); + // data packet received + std::unique_ptr createPacketEvent( + const QuicPacketHeader& packetHeader, + uint64_t packetSize, + bool isPacketRecvd); + void finishCreatePacketEvent(std::unique_ptr event); + // serialized packet to be sent + void addPacket( + uint64_t packet_number, + uint64_t packet_length, + TransmissionType transmission_type, + EncryptionLevel encryption_level, + const QuicFrames& retransmittable_frames, + const QuicFrames& nonretransmittable_frames, + bool isPacketRecvd); + + void addConnectionClose( + QuicErrorCode error, + const std::string& reason, + ConnectionCloseSource source) override; + void addPacingMetricUpdate( + uint64_t pacingBurstSizeIn, + std::chrono::microseconds pacingIntervalIn) override; + void addPacingObservation( + std::string& actual, + std::string& expected, + std::string& conclusion) override; + void addBandwidthEstUpdate(uint64_t bytes, std::chrono::microseconds interval) + override; + void addAppLimitedUpdate() override; + void addAppUnlimitedUpdate() override; + void addAppIdleUpdate(std::string& idleEvent, bool idle) override; + void addPacketDrop(size_t packetSize, std::string& dropReasonIn) override; + void addDatagramReceived(uint64_t dataLen) override; + void addLossAlarm( + uint64_t largestSent, + uint64_t alarmCount, + uint64_t outstandingPackets, + std::string& type) override; + void addPacketLost( + uint64_t LostPacketNum, + EncryptionLevel level, + TransmissionType type) override; + void addTransportStateUpdate(std::string& update) override; + void addPacketBuffered( + uint64_t packetNum, + EncryptionLevel protectionType, + uint64_t packetSize) override; + void addMetricUpdate( + std::chrono::microseconds latestRtt, + std::chrono::microseconds mrtt, + std::chrono::microseconds srtt, + std::chrono::microseconds ackDelay) override; + void addStreamStateUpdate( + QuicStreamId id, + std::string& update, + quiche::QuicheOptionalImpl timeSinceStreamCreation) + override; + virtual void addConnectionMigrationUpdate(bool intentionalMigration) override; + virtual void addPathValidationEvent(bool success) override; + void addPriorityUpdate( + quic::QuicStreamId streamId, + uint8_t urgency, + bool incremental) override; + + void outputLogsToFile(const std::string& path, bool prettyJson); + Document toJson(); + void toJsonBase(Document& j, Document& trace); + Document generateSummary( + size_t numEvents, + std::chrono::microseconds endTime); + + void setDcid(quiche::QuicheOptionalImpl connID) override; + void setScid(quiche::QuicheOptionalImpl connID) override; + void setQuicVersion(const QuicTransportVersion version) override; + void usedZeroRtt(bool use) override; + + void addSummary(Value& value, Document::AllocatorType& summary_allocator); + + void initialSummary(); + void switchSpdlogObject(const std::string& tmp_path, const std::string& final_path, uint64_t switch_qlog_index); + + private: + void createBaseJson(); + void setFileObject(); + void setSpdlogObject(); + void setupStream(); + void finishStream(); + void handleEvent(std::unique_ptr event); + + std::string path_; + std::string basePadding_ = " "; + std::string eventsPadding_ = ""; + std::string eventLine_; + std::string token_; + std::string endLine_; + std::string metadata_head_; + std::stringstream baseJson_; + std::string metadata_head_extra_; + StringBuffer buffer_; + Writer writer_; + std::string logstring_; + + std::string error_; + std::string reason_; + std::string source_; + + std::ofstream fileObj_; + std::shared_ptr logger_; + std::shared_ptr tp_; + + std::size_t log_event_buffer_; + + uint64_t connection_duration_; + bool prettyJson_; + bool streaming_; + int numEvents_ = 0; + size_t pos_; +}; +} // namespace quic diff --git a/base/bvc-qlog/src/qlogger.cc b/base/bvc-qlog/src/qlogger.cc new file mode 100644 index 00000000..428c1772 --- /dev/null +++ b/base/bvc-qlog/src/qlogger.cc @@ -0,0 +1,37 @@ +#include "base/bvc-qlog/src/qlogger.h" + +#include "gquiche/quic/core/quic_stream.h" +#include "gquiche/quic/core/quic_packets.h" +#include "platform/quiche_platform_impl/quiche_text_utils_impl.h" + +namespace quic { + +std::string getFlowControlEvent(int offset) { + return "flow control event, new offset: " + std::to_string(offset); +} + +std::string getRxStreamWU(QuicStreamId streamId, uint64_t packetNum, uint64_t maximumData) { + return "rx stream, streamId: " + quiche::QuicheTextUtilsImpl::Uint64ToString(streamId) + + ", packetNum: " + quiche::QuicheTextUtilsImpl::Uint64ToString(packetNum) + + ", maximumData: " + quiche::QuicheTextUtilsImpl::Uint64ToString(maximumData); +} + +std::string getRxConnWU(uint64_t packetNum, uint64_t maximumData) { + return "rx, packetNum: " + quiche::QuicheTextUtilsImpl::Uint64ToString(packetNum) + + ", maximumData: " + quiche::QuicheTextUtilsImpl::Uint64ToString(maximumData); +} + +std::string getPeerClose(const std::string& peerCloseReason) { + return "error message: " + peerCloseReason; +} + +std::string getFlowControlWindowAvailable(uint64_t windowAvailable) { + return "on flow control, window available: " + + quiche::QuicheTextUtilsImpl::Uint64ToString(windowAvailable); +} + +std::string getClosingStream(const std::string& streamId) { + return "closing stream, stream id: " + streamId; +} + +} // namespace quic diff --git a/base/bvc-qlog/src/qlogger.h b/base/bvc-qlog/src/qlogger.h new file mode 100644 index 00000000..6f187768 --- /dev/null +++ b/base/bvc-qlog/src/qlogger.h @@ -0,0 +1,144 @@ +#pragma once +#include + +#include "base/bvc-qlog/src/qlogger_constants.h" +#include "gquiche/quic/core/quic_stream.h" +#include "gquiche/quic/core/quic_connection_id.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_versions.h" +#include "gquiche/quic/core/quic_types.h" +#include "platform/quiche_platform_impl/quiche_optional_impl.h" + +namespace quic { + +struct PacingObserver { + PacingObserver() = default; + virtual ~PacingObserver() = default; + virtual void onNewPacingRate( + uint64_t packetsPerInterval, + std::chrono::microseconds interval) = 0; + virtual void onPacketSent() = 0; +}; + +class QLogger { + public: + explicit QLogger(VantagePoint vantagePointIn, std::string protocolTypeIn) + : vantagePoint_(vantagePointIn), protocolType_(std::move(protocolTypeIn)) {} + + quiche::QuicheOptionalImpl dcid_; + quiche::QuicheOptionalImpl scid_; + VantagePoint vantagePoint_; + std::string protocolType_; + QLogger() = delete; + virtual ~QLogger() = default; + virtual void addPacket( + const std::string& newConnectionId, + uint64_t packetSize, + bool isPacketRecvd) = 0; + virtual void addPacket( + const QuicPublicResetPacket& publicResetPacket, + uint64_t packetSize, + bool isPacketRecvd) = 0; + virtual void addPacket( + const QuicVersionNegotiationPacket& versionNegotiationPacket, + uint64_t packetSize, + bool isPacketRecvd) = 0; + virtual void addConnectionClose( + QuicErrorCode error, + const std::string& reason, + ConnectionCloseSource source) = 0; + struct TransportSummaryArgs { + uint64_t totalBytesSent{}; + uint64_t totalPacketsSent{}; + uint64_t totalBytesRecvd{}; + uint64_t totalPacketsRecvd{}; + uint64_t sumCurWriteOffset{}; + uint64_t sumMaxObservedOffset{}; + uint64_t sumCurStreamBufferLen{}; + uint64_t totalPacketsLost{}; + uint64_t totalStartupDuration{}; + uint64_t totalDrainDuration{}; + uint64_t totalProbeBWDuration{}; + uint64_t totalProbeRttDuration{}; + uint64_t totalNotRecoveryDuration{}; + uint64_t totalGrowthDuration{}; + uint64_t totalConservationDuration{}; + uint64_t totalStreamBytesCloned{}; + uint64_t totalBytesCloned{}; + uint64_t totalCryptoDataWritten{}; + uint64_t totalCryptoDataRecvd{}; + uint64_t currentWritableBytes{}; + uint64_t currentConnFlowControl{}; + double smoothedMaxBandwidth{}; + double smoothedMinRtt{}; + double smoothedMeanDeviation{}; + bool usedZeroRtt{false}; + QuicTransportVersion quicVersion{}; + CongestionControlType congestionControl{}; + }; + + virtual void addBandwidthEstUpdate( + uint64_t bytes, + std::chrono::microseconds interval) = 0; + virtual void addAppLimitedUpdate() = 0; + virtual void addAppUnlimitedUpdate() = 0; + virtual void addPacingMetricUpdate( + uint64_t pacingBurstSizeIn, + std::chrono::microseconds pacingIntervalIn) = 0; + virtual void addPacingObservation( + std::string& actual, + std::string& expected, + std::string& conclusion) = 0; + virtual void addAppIdleUpdate(std::string& idleEvent, bool idle) = 0; + virtual void addPacketDrop(size_t packetSize, std::string& dropReasonIn) = 0; + virtual void addDatagramReceived(uint64_t dataLen) = 0; + virtual void addLossAlarm( + uint64_t largestSent, + uint64_t alarmCount, + uint64_t outstandingPackets, + std::string& type) = 0; + virtual void addPacketLost( + uint64_t LostPacketNum, + EncryptionLevel level, + TransmissionType type) = 0; + virtual void addTransportStateUpdate(std::string& update) = 0; + virtual void addPacketBuffered( + uint64_t packetNum, + EncryptionLevel protectionType, + uint64_t packetSize) = 0; + virtual void addMetricUpdate( + std::chrono::microseconds latestRtt, + std::chrono::microseconds mrtt, + std::chrono::microseconds srtt, + std::chrono::microseconds ackDelay) = 0; + virtual void addStreamStateUpdate( + quic::QuicStreamId streamId, + std::string& update, + quiche::QuicheOptionalImpl timeSinceStreamCreation) = 0; + virtual void addConnectionMigrationUpdate(bool intentionalMigration) = 0; + virtual void addPathValidationEvent(bool success) = 0; + virtual void addPriorityUpdate( + quic::QuicStreamId streamId, + uint8_t urgency, + bool incremental) = 0; + + virtual void setDcid(quiche::QuicheOptionalImpl connID) = 0; + virtual void setScid(quiche::QuicheOptionalImpl connID) = 0; + virtual void setQuicVersion(const QuicTransportVersion version) = 0; + virtual void usedZeroRtt(bool use) = 0; + TransportSummaryArgs summary_; +}; + +std::string getFlowControlEvent(int offset); + +std::string getRxStreamWU(QuicStreamId streamId, uint64_t packetNum, uint64_t maximumData); + +std::string getRxConnWU(uint64_t packetNum, uint64_t maximumData); + +std::string getPeerClose(const std::string& errMsg); + +std::string getFlowControlWindowAvailable(uint64_t windowAvailable); + +std::string getClosingStream(const std::string& streamId); + +} // namespace quic diff --git a/base/bvc-qlog/src/qlogger_constants.cc b/base/bvc-qlog/src/qlogger_constants.cc new file mode 100644 index 00000000..c96184e9 --- /dev/null +++ b/base/bvc-qlog/src/qlogger_constants.cc @@ -0,0 +1,218 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + * + */ + +#include "base/bvc-qlog/src/qlogger_constants.h" + +namespace quic { +quiche::QuicheStringPiece vantagePointString(VantagePoint vantagePoint) { + switch (vantagePoint) { + case VantagePoint::IS_CLIENT: + return kQLogClientVantagePoint; + case VantagePoint::IS_SERVER: + return kQLogServerVantagePoint; + default: + return "unknown_perspective"; + } +} + +quiche::QuicheStringPiece toQlogString(QuicFrameType frame) { + switch (frame) { + case QuicFrameType::PADDING_FRAME: + return "padding"; + case QuicFrameType::RST_STREAM_FRAME: + return "rst_stream"; + case QuicFrameType::CONNECTION_CLOSE_FRAME: + return "connection_close"; + case QuicFrameType::GOAWAY_FRAME: + return "go_away"; + case QuicFrameType::WINDOW_UPDATE_FRAME: + return "window_update"; + case QuicFrameType::BLOCKED_FRAME: + return "blocked"; + case QuicFrameType::STOP_WAITING_FRAME: + return "stop_waiting"; + case QuicFrameType::PING_FRAME: + return "ping"; + case QuicFrameType::ACK_FRAME: + return "ack"; + case QuicFrameType::STREAM_FRAME: + return "stream"; + case QuicFrameType::CRYPTO_FRAME: + return "crypto"; + case QuicFrameType::HANDSHAKE_DONE_FRAME: + return "handshake_done"; + case QuicFrameType::MTU_DISCOVERY_FRAME: + return "mtu_discovery"; + case QuicFrameType::NEW_CONNECTION_ID_FRAME: + return "new_connection_id"; + case QuicFrameType::MAX_STREAMS_FRAME: + return "max_streams"; + case QuicFrameType::STREAMS_BLOCKED_FRAME: + return "streams_blocked"; + case QuicFrameType::PATH_RESPONSE_FRAME: + return "path_response"; + case QuicFrameType::PATH_CHALLENGE_FRAME: + return "path_challenge"; + case QuicFrameType::STOP_SENDING_FRAME: + return "stop_sending"; + case QuicFrameType::MESSAGE_FRAME: + return "message"; + case QuicFrameType::NEW_TOKEN_FRAME: + return "new_token"; + case QuicFrameType::RETIRE_CONNECTION_ID_FRAME: + return "retire_connection_id"; + case QuicFrameType::ACK_FREQUENCY_FRAME: + return "ack_frequency"; + default: + return "unknown_frame"; + } +} + +quiche::QuicheStringPiece toQlogString(QuicLongHeaderType type) { + switch (type) { + case QuicLongHeaderType::INITIAL: + return "initial"; + case QuicLongHeaderType::RETRY: + return "RETRY"; + case QuicLongHeaderType::HANDSHAKE: + return "handshake"; + case QuicLongHeaderType::ZERO_RTT_PROTECTED: + return "0RTT"; + case QuicLongHeaderType::VERSION_NEGOTIATION: + return "version_negotiation"; + case QuicLongHeaderType::INVALID_PACKET_TYPE: + return "invalid"; + default: + return "unknown_header_type"; + } +} + +quiche::QuicheStringPiece toQlogString(EncryptionLevel level) { + switch (level) { + case EncryptionLevel::ENCRYPTION_INITIAL: + return "initial"; + case EncryptionLevel::ENCRYPTION_HANDSHAKE: + return "handshake"; + case EncryptionLevel::ENCRYPTION_ZERO_RTT: + return "zero_rtt"; + case EncryptionLevel::ENCRYPTION_FORWARD_SECURE: + return "forward_secure"; + case EncryptionLevel::NUM_ENCRYPTION_LEVELS: + default: + return "invalid_encryption_level"; + } +} + +QuicLongHeaderType encryptionLevelToLongHeaderType(EncryptionLevel level) { + switch (level) { + case EncryptionLevel::ENCRYPTION_INITIAL: + return QuicLongHeaderType::INITIAL; + case EncryptionLevel::ENCRYPTION_HANDSHAKE: + return QuicLongHeaderType::HANDSHAKE; + case EncryptionLevel::ENCRYPTION_ZERO_RTT: + return QuicLongHeaderType::ZERO_RTT_PROTECTED; + case EncryptionLevel::ENCRYPTION_FORWARD_SECURE: + default: + return QuicLongHeaderType::INVALID_PACKET_TYPE; + } +} + +quiche::QuicheStringPiece toQlogString(TransmissionType type) { + switch (type) { + case TransmissionType::NOT_RETRANSMISSION: + return "not_retransmission"; + case TransmissionType::HANDSHAKE_RETRANSMISSION: + return "handshake_retransmission"; + case TransmissionType::ALL_ZERO_RTT_RETRANSMISSION: + return "all_zero_rtt_retransmission"; + case TransmissionType::LOSS_RETRANSMISSION: + return "loss_retransmission"; + case TransmissionType::RTO_RETRANSMISSION: + return "rto_retransmission"; + case TransmissionType::TLP_RETRANSMISSION: + return "tlp_retransmission"; + case TransmissionType::PTO_RETRANSMISSION: + return "pto_retransmission"; + case TransmissionType::PROBING_RETRANSMISSION: + return "probing_retransmission"; + default: + return "invalid_retransmission"; + } +} + +quiche::QuicheStringPiece toQlogString(CongestionControlType type) { + switch (type) { + case kCubicBytes: + return "Cubic"; + case kRenoBytes: + return "Reno"; + case kBBR: + return "BBR"; + case kPCC: + return "PCC"; + case kGoogCC: + return "GoogCC"; + case kBBRv2: + return "BBRv2"; + } + return "invalid_type"; +} + +quiche::QuicheStringPiece toQlogString(BbrSender::Mode mode) { + switch (mode) { + case BbrSender::STARTUP: + return "STARTUP"; + case BbrSender::DRAIN: + return "DRAIN"; + case BbrSender::PROBE_BW: + return "PROBE_BW"; + case BbrSender::PROBE_RTT: + return "PROBE_RTT"; + } + return "invalid_mode"; +} + +quiche::QuicheStringPiece toQlogString(Bbr2Mode mode) { + switch (mode) { + case Bbr2Mode::STARTUP: + return "STARTUP"; + case Bbr2Mode::DRAIN: + return "DRAIN"; + case Bbr2Mode::PROBE_BW: + return "PROBE_BW"; + case Bbr2Mode::PROBE_RTT: + return "PROBE_RTT"; + } + return "InvalidMode"; +} + +quiche::QuicheStringPiece toQlogString(BbrSender::RecoveryState recovery_state) { + switch (recovery_state) { + case BbrSender::NOT_IN_RECOVERY: + return "NOT_IN_RECOVERY"; + case BbrSender::CONSERVATION: + return "CONSERVATION"; + case BbrSender::GROWTH: + return "GROWTH"; + } + return "invalid_state"; +} + +quiche::QuicheStringPiece toQlogString(QuicConnectionCloseType close_type) { + switch (close_type) { + case GOOGLE_QUIC_CONNECTION_CLOSE: + return "GOOGLE_QUIC_CONNECTION_CLOSE"; + case IETF_QUIC_TRANSPORT_CONNECTION_CLOSE: + return "IETF_QUIC_TRANSPORT_CONNECTION_CLOSE"; + case IETF_QUIC_APPLICATION_CONNECTION_CLOSE: + return "IETF_APPLICATION_CONNECTION_CLOSE"; + } + return "invalid_type"; +} + +}// namespace quic diff --git a/base/bvc-qlog/src/qlogger_constants.h b/base/bvc-qlog/src/qlogger_constants.h new file mode 100644 index 00000000..7a04b300 --- /dev/null +++ b/base/bvc-qlog/src/qlogger_constants.h @@ -0,0 +1,116 @@ +#pragma once + +#include "platform/quiche_platform_impl/quiche_text_utils_impl.h" +#include "gquiche/quic/core/quic_types.h" +#include "platform/quiche_platform_impl/quiche_optional_impl.h" +#include "gquiche/quic/core/congestion_control/bbr_sender.h" +#include "gquiche/quic/core/congestion_control/bbr2_sender.h" +#include "gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h" + +namespace quic { +constexpr quiche::QuicheStringPiece kShortHeaderPacketType = "1RTT"; +constexpr quiche::QuicheStringPiece kGooglePacketType = "GQUIC"; +constexpr uint8_t kQuicFrameTypeBrokenMask = 0xE0; +constexpr uint8_t kQuicFrameTypeSpecialMask = 0xC0; +constexpr uint8_t kQuicFrameTypeStreamMask = 0x80; +constexpr uint8_t kQuicFrameTypeAckMask = 0x40; +constexpr auto kVersionNegotiationPacketType = "version_negotiation"; +constexpr auto kQuicPublicResetPacketType = "quic_public_reset"; +constexpr auto kQuicIetfStatelessResetPacketType = "ietf_stateless_reset"; +constexpr auto kHTTP3ProtocolType = "QUIC_HTTP3"; +constexpr auto kNoError = "no error"; +constexpr auto kGracefulExit = "graceful exit"; +constexpr auto kPersistentCongestion = "persistent congestion"; +constexpr auto kRemoveInflight = "remove bytes in flight"; +constexpr auto kCubicSkipLoss = "cubic skip loss"; +constexpr auto kCubicLoss = "cubic loss"; +constexpr auto kCubicSteadyCwnd = "cubic steady cwnd"; +constexpr auto kCubicSkipAck = "cubic skip ack"; +constexpr auto kCubicInit = "cubic init"; +constexpr auto kCongestionPacketAck = "congestion packet ack"; +constexpr auto kCwndNoChange = "cwnd no change"; +constexpr auto kAckInQuiescence = "ack in quiescence"; +constexpr auto kResetTimeToOrigin = "reset time to origin"; +constexpr auto kResetLastReductionTime = "reset last reduction time"; +constexpr auto kRenoCwndEstimation = "reno cwnd estimation"; +constexpr auto kPacketAckedInRecovery = "packet acked in recovery"; +constexpr auto kCopaInit = "copa init"; +constexpr auto kCongestionPacketSent = "congestion on packet sent"; +constexpr auto kCopaCheckAndUpdateDirection = "copa check and update direction"; +constexpr auto kCongestionPacketLoss = "congestion packet loss"; +constexpr auto kAppLimited = "app limited"; +constexpr auto kAppUnlimited = "app unlimited"; +constexpr uint64_t kDefaultCwnd = 12320; +constexpr auto kAppIdle = "app idle"; +constexpr auto kMaxBuffered = "max buffered"; +constexpr auto kCipherUnavailable = "cipher unavailable"; +constexpr auto kParse = "parse"; +constexpr auto kNonRegular = "non regular"; +constexpr auto kAlreadyClosed = "already closed"; +constexpr auto kUdpTruncated = "udp truncated"; +constexpr auto kNoData = "no data"; +constexpr auto kUnexpectedProtectionLevel = "unexpected protection level"; +constexpr auto kBufferUnavailable = "buffer unavailable"; +constexpr auto kReset = "reset"; +constexpr auto kRetry = "retry"; +constexpr auto kPtoAlarm = "pto alarm"; +constexpr auto kHandshakeAlarm = "handshake alarm"; +constexpr auto kLossTimeoutExpired = "loss timeout expired"; +constexpr auto kStart = "start"; +constexpr auto kWriteNst = "write nst"; +constexpr auto kTransportReady = "transport ready"; +constexpr auto kDerivedZeroRttReadCipher = "derived 0-rtt read cipher"; +constexpr auto kDerivedOneRttReadCipher = "derived 1-rtt read cipher"; +constexpr auto kDerivedOneRttWriteCipher = "derived 1-rtt write cipher"; +constexpr auto kZeroRttRejected = "zerortt rejected"; +constexpr auto kZeroRttAccepted = "zerortt accepted"; +constexpr auto kZeroRttAttempted = "zerortt attempted"; +constexpr auto kRecalculateTimeToOrigin = "recalculate time to origin"; +constexpr auto kAbort = "abort"; +constexpr auto kQLogVersion = "draft-00"; +constexpr auto kQLogTitle = "bvc-qlog"; +constexpr auto kQLogDescription = "Converted from file"; +constexpr auto kQLogTraceTitle = "bvc-qlog from single connection"; +constexpr auto kQLogTraceDescription = "Generated qlog from connection"; +constexpr auto kQLogTimeUnits = "us"; +constexpr auto kQLogVersionField = "qlog_version"; +constexpr auto kQLogTitleField = "title"; +constexpr auto kQLogDescriptionField = "description"; +constexpr auto kQLogTraceCountField = "trace_count"; +constexpr auto kEOM = "eom"; +constexpr auto kOnEOM = "on eom"; +constexpr auto kStreamBlocked = "stream blocked"; +constexpr auto kHeaders = "headers"; +constexpr auto kOnHeaders = "on headers"; +constexpr auto kOnError = "on error"; +constexpr auto kPushPromise = "push promise"; +constexpr auto kBody = "body"; + +constexpr quiche::QuicheStringPiece kQLogServerVantagePoint = "server"; +constexpr quiche::QuicheStringPiece kQLogClientVantagePoint = "client"; + +using VantagePoint = Perspective; + +quiche::QuicheStringPiece vantagePointString(VantagePoint vantagePoint); + +quiche::QuicheStringPiece toQlogString(QuicFrameType frame); + +quiche::QuicheStringPiece toQlogString(QuicLongHeaderType type); + +quiche::QuicheStringPiece toQlogString(EncryptionLevel level); + +QuicLongHeaderType encryptionLevelToLongHeaderType(EncryptionLevel level); + +quiche::QuicheStringPiece toQlogString(TransmissionType type); + +quiche::QuicheStringPiece toQlogString(CongestionControlType type); + +quiche::QuicheStringPiece toQlogString(BbrSender::Mode mode); + +quiche::QuicheStringPiece toQlogString(Bbr2Mode mode); + +quiche::QuicheStringPiece toQlogString(BbrSender::RecoveryState recovery_state); + +quiche::QuicheStringPiece toQlogString(QuicConnectionCloseType close_type); + +} // namespace quic diff --git a/base/bvc-qlog/src/qlogger_types.cc b/base/bvc-qlog/src/qlogger_types.cc new file mode 100644 index 00000000..9d883dad --- /dev/null +++ b/base/bvc-qlog/src/qlogger_types.cc @@ -0,0 +1,1702 @@ +#include "base/bvc-qlog/src/qlogger_types.h" +#include "base/bvc-qlog/src/qlogger_constants.h" +#include "platform/quiche_platform_impl/quiche_text_utils_impl.h" +#include "gquiche/quic/core/quic_error_codes.h" + +namespace quic { + +Document QLogFrame::toShortJson() const { + Document j; + return j; +} + +Document PaddingFrameLog::toJson() const { + Document j; + j.SetObject(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::PADDING_FRAME).data(), j.GetAllocator()).Move(), + j.GetAllocator()); + return j; +} + +// TODO: padding is not showing up in short form, need to fix +Document PaddingFrameLog::toShortJson() const { + Document j; + j.SetArray(); + j.PushBack("padding", j.GetAllocator()); + return j; +} + +Document RstStreamFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::RST_STREAM_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("stream_id", + streamId, + j_allocator); + j.AddMember("error_code", + errorCode, + j_allocator); + j.AddMember("offset", + offset, + j_allocator); + return j; +} + +Document ConnectionCloseFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::CONNECTION_CLOSE_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("wire_error_code", + wireErrorCode, + j_allocator); + j.AddMember("quic_error_code", + Value(QuicErrorCodeToString(quicErrorCode), j_allocator).Move(), + j_allocator); + j.AddMember("error_details", + Value(errorDetails.c_str(), j_allocator).Move(), + j_allocator); + j.AddMember("close_type", + Value((toQlogString(closeType)).data(), j_allocator).Move(), + j_allocator); + j.AddMember("transport_closing_frame_type", + transportCloseFrameType, + j_allocator); + return j; +} + +Document GoAwayFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::GOAWAY_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("error_code", + Value(QuicErrorCodeToString(errorCode), j_allocator).Move(), + j_allocator); + j.AddMember("reason_phrase", + Value(reasonPhrase.c_str(), j_allocator).Move(), + j_allocator); + j.AddMember("last_good_stream_id", + lastGoodStreamId, + j_allocator); + return j; +} + +Document WindowUpdateFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::WINDOW_UPDATE_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("stream_id", + streamId, + j_allocator); + j.AddMember("max_data", + maxData, + j_allocator); + return j; +} + +Document BlockedFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::BLOCKED_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("stream_id", + streamId, + j_allocator); + return j; +} + +Document StopWaitingFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::STOP_WAITING_FRAME).data(), j_allocator).Move(), + j_allocator); + return j; +} + +Document PingFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::PING_FRAME).data(), j_allocator).Move(), + j_allocator); + return j; +} + +Document AckFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetArray(); + + Value temp_value; + for (auto interval : packetNumberQueue) { + temp_value.SetArray(); + //*: Because the interval uses half-closed range `[)` and causes confusion, + //*: minus 1 here to make it closed range '[]' + if(interval.max().ToUint64() - 1 == interval.min().ToUint64()) { + temp_value.PushBack(interval.min().ToUint64(), j_allocator); + } else { + temp_value.PushBack(interval.min().ToUint64(), j_allocator); + temp_value.PushBack(interval.max().ToUint64() - 1, j_allocator); + } + value.PushBack(temp_value, j_allocator); + } + + j.AddMember("acked_ranges", + value, + j_allocator); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::ACK_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("ack_delay", + ackDelay.count(), + j_allocator); + return j; +} + +Document AckFrameLog::toShortJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetArray(); + + Value tmp_arr; + for (auto interval : packetNumberQueue) { + tmp_arr.SetArray(); + if(interval.max().ToUint64() - 1 == interval.min().ToUint64()) { + tmp_arr.PushBack(interval.min().ToUint64(), j_allocator); + } else { + tmp_arr.PushBack(interval.min().ToUint64(), j_allocator); + tmp_arr.PushBack(interval.max().ToUint64() - 1, j_allocator); + } + value.PushBack(tmp_arr, j_allocator); + } + + j.PushBack(value, j_allocator); + j.PushBack(ackDelay.count(), j_allocator); + return j; +} + +Document StreamFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::STREAM_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("stream_id", + streamId, + j_allocator); + j.AddMember("offset", + offset, + j_allocator); + j.AddMember("length", + len, + j_allocator); + j.AddMember("fin", + fin, + j_allocator); + return j; +} + +Document StreamFrameLog::toShortJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.PushBack(streamId, j_allocator); + j.PushBack(offset, j_allocator); + j.PushBack(len, j_allocator); + j.PushBack(fin, j_allocator); + return j; +} + +Document CryptoFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("encryption_level", + Value(toQlogString(level).data(), j_allocator).Move(), + j_allocator); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::CRYPTO_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("offset", + offset, + j_allocator); + j.AddMember("data_length", + dataLength, + j_allocator); + return j; +} + +Document HandshakeDoneFrameLog::toJson() const { + Document j; + j.SetObject(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::HANDSHAKE_DONE_FRAME).data(), j.GetAllocator()).Move(), + j.GetAllocator()); + return j; +} + +Document MTUDiscoveryFrameLog::toJson() const { + Document j; + j.SetObject(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::MTU_DISCOVERY_FRAME).data(), j.GetAllocator()).Move(), + j.GetAllocator()); + return j; +} + +Document NewConnectionIdFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::NEW_CONNECTION_ID_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("sequence", + sequenceNumber, + j_allocator); + return j; +} + +Document MaxStreamsFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::MAX_STREAMS_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("max_streams", + streamCount, + j_allocator); + j.AddMember("direction", + Value(unidirectional ? "unidirectional" : "bidirectional", j_allocator).Move(), + j_allocator); + return j; +} + +Document StreamsBlockedFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::STREAMS_BLOCKED_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("max_streams", + streamCount, + j_allocator); + j.AddMember("direction", + Value(unidirectional ? "unidirectional" : "bidirectional", j_allocator).Move(), + j_allocator); + return j; +} + +Document PathResponseFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::PATH_RESPONSE_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("path_data", + Value(pathData.c_str(), j_allocator).Move(), + j_allocator); + return j; +} + +Document PathChallengeFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::PATH_CHALLENGE_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("path_data", + Value(pathData.c_str(), j_allocator).Move(), + j_allocator); + return j; +} + +Document StopSendingFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::STOP_SENDING_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("stream_id", + streamId, + j_allocator); + j.AddMember("error_code", + Value(QuicRstStreamErrorCodeToString(errorCode), j_allocator).Move(), + j_allocator); + return j; +} + +Document MessageFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::MESSAGE_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("message_id", + messageId, + j_allocator); + j.AddMember("length", + length, + j_allocator); + return j; +} + +Document NewTokenFrameLog::toJson() const { + Document j; + j.SetObject(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::NEW_TOKEN_FRAME).data(), j.GetAllocator()).Move(), + j.GetAllocator()); + return j; +} + +Document RetireConnectionIdFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::RETIRE_CONNECTION_ID_FRAME).data(), j_allocator).Move(), + j_allocator); + j.AddMember("sequence", + sequenceNumber, + j_allocator); + return j; +} + +Document AckFrequencyFrameLog::toJson() const { + Document j; + j.SetObject(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.AddMember("frame_type", + Value(toQlogString(QuicFrameType::ACK_FREQUENCY_FRAME).data(), j_allocator ).Move(), + j_allocator); + j.AddMember("sequence_number", + sequenceNumber, + j_allocator); + j.AddMember("packet_tolerance", + packetTolerance, + j_allocator); + j.AddMember("update_max_ack_delay", + updateMaxAckDelay, + j_allocator); + j.AddMember("ignore_order", + ignoreOrder, + j_allocator); + return j; +} + +Document VersionNegotiationLog::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + for (const auto& v : versions) { + j.PushBack(Value(ParsedQuicVersionToString(v).c_str(), j_allocator).Move(), j_allocator); + } + return j; +} + +Document QLogFramesProcessed::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + + Value value; + value.SetObject(); + value.AddMember("frames_type", + Value(toQlogString(framesType).data(), j_allocator).Move(), + j_allocator); + + //frames_fields + Value tmp_arr; + tmp_arr.SetArray(); + switch(framesType) { + case QuicFrameType::STREAM_FRAME: + tmp_arr.PushBack("stream_id", j_allocator); + tmp_arr.PushBack("offset", j_allocator); + tmp_arr.PushBack("length", j_allocator); + tmp_arr.PushBack("fin", j_allocator); + break; + case QuicFrameType::ACK_FRAME: + tmp_arr.PushBack("acked_ranges", j_allocator); + tmp_arr.PushBack("ack_delay", j_allocator); + break; + default: + break; + } + value.AddMember("frames_fields", + tmp_arr, + j_allocator); + + tmp_arr.SetArray(); + for (const auto& frame : frames) { + tmp_arr.PushBack(Value().CopyFrom(frame->toShortJson(), j_allocator).Move(), j_allocator); + } + value.AddMember("frames", + tmp_arr, + j_allocator); + + tmp_arr.SetArray(); + for (const auto& packetSize : packetSizes) { + tmp_arr.PushBack(packetSize, j_allocator); + } + value.AddMember("packetSizes", + tmp_arr, + j_allocator); + + tmp_arr.SetArray(); + for (const auto& timeDrift : timeDrifts) { + tmp_arr.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(timeDrift.count()).c_str(), j_allocator).Move(), j_allocator); + } + value.AddMember("timeDrift", + tmp_arr, + j_allocator); + + + tmp_arr.SetArray(); + for (const auto& packetNum : packetNums) { + tmp_arr.PushBack(packetNum, j_allocator); + } + value.AddMember("packetNums", + tmp_arr, + j_allocator); + + value.AddMember("packet_type", + Value(packetType.c_str(), j_allocator).Move(), + j_allocator); + + j.PushBack(value, j_allocator); + return j; +} + +Document QLogPacketEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + + Value tmp_object, tmp_arr; + tmp_object.SetObject(); + tmp_arr.SetArray(); + tmp_object.AddMember("packet_size", + packetSize, + j_allocator); + if (packetType != toQlogString(QuicLongHeaderType::RETRY)) { + tmp_object.AddMember("packet_number", + packetNum, + j_allocator); + + for (const auto& frame : frames) { + tmp_arr.PushBack(Value().CopyFrom(frame->toJson(), j_allocator).Move(), j_allocator); + } + } + + Value value; + value.SetObject(); + value.AddMember("header", + tmp_object, + j_allocator); + value.AddMember("frames", + tmp_arr, + j_allocator); + value.AddMember("packet_type", + Value(packetType.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("transmission_type", + Value(transmissionType.c_str(), j_allocator).Move(), + j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +Document QLogVersionNegotiationEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value tmp_object; + tmp_object.SetObject(); + tmp_object.AddMember("packet_size", + packetSize, + j_allocator); + + Value value; + value.SetObject(); + value.AddMember("versions", + versionLog->toJson(), + j_allocator); + value.AddMember("header", + tmp_object, + j_allocator); + value.AddMember("packet_type", + Value(packetType.c_str(), j_allocator).Move(), + j_allocator); + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +Document QLogRetryEvent::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value tmp_object; + tmp_object.SetObject(); + tmp_object.AddMember("packet_size", + packetSize, + j_allocator); + + Value value; + value.SetObject(); + value.AddMember("header", + tmp_object, + j_allocator); + value.AddMember("packet_type", + Value(packetType.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("token_size", + tokenSize, + j_allocator); + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogConnectionCloseEvent::QLogConnectionCloseEvent( + QuicErrorCode errorIn, + std::string reasonIn, + ConnectionCloseSource sourceIn, + std::chrono::microseconds refTimeIn) + : error{std::move(errorIn)}, + reason{std::move(reasonIn)}, + source{sourceIn} { + eventType = QLogEventType::ConnectionClose; + refTime = refTimeIn; +} + +Document QLogConnectionCloseEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("error", + Value(QuicErrorCodeToString(error), j_allocator).Move(), + j_allocator); + value.AddMember("reason", + Value(reason.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("source", + Value(ConnectionCloseSourceToString(source).c_str(), j_allocator).Move(), + j_allocator); + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("connectivity", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogTransportSummaryEvent::QLogTransportSummaryEvent( + uint64_t totalBytesSentIn, + uint64_t totalPacketsSentIn, + uint64_t totalBytesRecvdIn, + uint64_t totalPacketsRecvdIn, + uint64_t sumCurWriteOffsetIn, + uint64_t sumMaxObservedOffsetIn, + uint64_t sumCurStreamBufferLenIn, + uint64_t totalStartupDurationIn, + uint64_t totalDrainDurationIn, + uint64_t totalProbeBWDurationIn, + uint64_t totalProbeRttDurationIn, + uint64_t totalNotRecoveryDurationIn, + uint64_t totalGrowthDurationIn, + uint64_t totalConservationDurationIn, + uint64_t totalPacketsLostIn, + uint64_t totalStreamBytesClonedIn, + uint64_t totalBytesClonedIn, + uint64_t totalCryptoDataWrittenIn, + uint64_t totalCryptoDataRecvdIn, + uint64_t currentWritableBytesIn, + uint64_t currentConnFlowControlIn, + bool usedZeroRttIn, + QuicTransportVersion quicVersionIn, + CongestionControlType congestionTypeIn, + double smoothedMinRttIn, + double smoothedMaxBandwidthIn, + float startupDurationRatioIn, + float drainDurationRatioIn, + float probebwDurationRatioIn, + float proberttDurationRatioIn, + float NotRecoveryDurationRatioIn, + float GrowthDurationRatioIn, + float ConservationDurationRatioIn, + float AverageDifferenceIn, + std::chrono::microseconds refTimeIn) + : totalBytesSent{totalBytesSentIn}, + totalPacketsSent{totalPacketsSentIn}, + totalBytesRecvd{totalBytesRecvdIn}, + totalPacketsRecvd{totalPacketsRecvdIn}, + sumCurWriteOffset{sumCurWriteOffsetIn}, + sumMaxObservedOffset{sumMaxObservedOffsetIn}, + sumCurStreamBufferLen{sumCurStreamBufferLenIn}, + totalPacketsLost{totalPacketsLostIn}, + totalStartupDuration{totalStartupDurationIn}, + totalDrainDuration{totalDrainDurationIn}, + totalProbeBWDuration{totalProbeBWDurationIn}, + totalProbeRttDuration{totalProbeRttDurationIn}, + totalNotRecoveryDuration{totalNotRecoveryDurationIn}, + totalGrowthDuration{totalGrowthDurationIn}, + totalConservationDuration{totalConservationDurationIn}, + totalStreamBytesCloned{totalStreamBytesClonedIn}, + totalBytesCloned{totalBytesClonedIn}, + totalCryptoDataWritten{totalCryptoDataWrittenIn}, + totalCryptoDataRecvd{totalCryptoDataRecvdIn}, + currentWritableBytes{currentWritableBytesIn}, + currentConnFlowControl{currentConnFlowControlIn}, + usedZeroRtt{usedZeroRttIn}, + quicVersion{quicVersionIn}, + congestionType{congestionTypeIn}, + smoothedMinRtt{smoothedMinRttIn}, + smoothedMaxBandwidth{smoothedMaxBandwidthIn}, + startupDurationRatio{startupDurationRatioIn}, + drainDurationRatio{drainDurationRatioIn}, + probebwDurationRatio{probebwDurationRatioIn}, + proberttDurationRatio{proberttDurationRatioIn}, + NotRecoveryDurationRatio {NotRecoveryDurationRatioIn}, + GrowthDurationRatio {GrowthDurationRatioIn}, + ConservationDurationRatio {ConservationDurationRatioIn}, + AverageDifference {AverageDifferenceIn} { + eventType = QLogEventType::TransportSummary; + refTime = refTimeIn; +} + +Document QLogTransportSummaryEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("total_bytes_sent", + totalBytesSent, + j_allocator); + value.AddMember("reatotal_packets_sentson", + totalPacketsSent, + j_allocator); + value.AddMember("total_bytes_recvd", + totalBytesRecvd, + j_allocator); + value.AddMember("total_packets_recvd", + totalPacketsRecvd, + j_allocator); + value.AddMember("sum_cur_write_offset", + sumCurWriteOffset, + j_allocator); + value.AddMember("sum_max_observed_offset", + sumMaxObservedOffset, + j_allocator); + value.AddMember("sum_cur_stream_buffer_len", + sumCurStreamBufferLen, + j_allocator); + value.AddMember("total_packets_lost", + totalPacketsLost, + j_allocator); + value.AddMember("total_stream_bytes_cloned", + totalStreamBytesCloned, + j_allocator); + value.AddMember("total_bytes_cloned", + totalBytesCloned, + j_allocator); + value.AddMember("total_crypto_data_written", + totalCryptoDataWritten, + j_allocator); + value.AddMember("total_crypto_data_recvd", + totalCryptoDataRecvd, + j_allocator); + value.AddMember("current_writable_bytes", + currentWritableBytes, + j_allocator); + value.AddMember("current_conn_flow_control", + currentConnFlowControl, + j_allocator); + value.AddMember("used_zero_rtt", + usedZeroRtt, + j_allocator); + value.AddMember("quic_version", + Value(QuicVersionToString(quicVersion).c_str(), j_allocator).Move(), + j_allocator); + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogBBRCongestionMetricUpdateEvent::QLogBBRCongestionMetricUpdateEvent( + uint64_t bytesInFlightIn, + uint64_t currentCwndIn, + std::string congestionEventIn, + CongestionControlType typeIn, + void* stateIn, + std::chrono::microseconds refTimeIn) + : bytesInFlight{bytesInFlightIn}, + currentCwnd{currentCwndIn}, + congestionEvent{std::move(congestionEventIn)}, + type{typeIn}, + state{stateIn} { + eventType = QLogEventType::CongestionMetricUpdate; + refTime = refTimeIn; +} + +Document QLogBBRCongestionMetricUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("bytes_in_flight", + bytesInFlight, + j_allocator); + value.AddMember("current_cwnd", + currentCwnd, + j_allocator); + value.AddMember("congestion_event", + Value(congestionEvent.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("congestion_control_type", + Value(toQlogString(type).data(), j_allocator).Move(), + j_allocator); + if (type == kBBR) { + BbrSender::DebugState* bbr_state = + static_cast(state); + Value tmp_object(kObjectType); + tmp_object.AddMember("mode", + Value(toQlogString(bbr_state->mode).data(), j_allocator).Move(), + j_allocator); + tmp_object.AddMember("max_bandwidth", + bbr_state->max_bandwidth.ToKBitsPerSecond(), + j_allocator); + tmp_object.AddMember("round_trip_counter", + bbr_state->round_trip_count, + j_allocator); + tmp_object.AddMember("gain_cycle_index", + static_cast(bbr_state->gain_cycle_index), + j_allocator); + tmp_object.AddMember("min_rtt", + bbr_state->min_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("latest_rtt", + bbr_state->latest_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("smoothed_rtt", + bbr_state->smoothed_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("mean_deviation", + bbr_state->mean_deviation.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("recovery_state", + Value(toQlogString(bbr_state->recovery_state).data(), j_allocator).Move(), + j_allocator); + value.AddMember("state", + tmp_object, + j_allocator); + } + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("metric_update", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogCubicCongestionMetricUpdateEvent::QLogCubicCongestionMetricUpdateEvent( + uint64_t bytesInFlightIn, + uint64_t currentCwndIn, + std::string congestionEventIn, + CongestionControlType typeIn, + void* stateIn, + std::chrono::microseconds refTimeIn) + : bytesInFlight{bytesInFlightIn}, + currentCwnd{currentCwndIn}, + congestionEvent{std::move(congestionEventIn)}, + type{typeIn}, + state{stateIn} { + eventType = QLogEventType::CongestionMetricUpdate; + refTime = refTimeIn; +} + +Document QLogCubicCongestionMetricUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + + value.AddMember("bytes_in_flight", + bytesInFlight, + j_allocator); + value.AddMember("current_cwnd", + currentCwnd, + j_allocator); + value.AddMember("congestion_event", + Value(congestionEvent.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("congestion_control_type", + Value(toQlogString(type).data(), j_allocator).Move(), + j_allocator); + if (type == kCubicBytes) { + TcpCubicSenderBytes::DebugState* cubic_state = + static_cast(state); + Value tmp_object(kObjectType); + tmp_object.AddMember("min_rtt", + cubic_state->min_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("latest_rtt", + cubic_state->latest_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("smoothed_rtt", + cubic_state->smoothed_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("mean_deviation", + cubic_state->mean_deviation.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("bandwidth_est", + cubic_state->bandwidth_est.ToKBitsPerSecond(), + j_allocator); + value.AddMember("state", + tmp_object, + j_allocator); + + } + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("metric_update", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogBBR2CongestionMetricUpdateEvent::QLogBBR2CongestionMetricUpdateEvent( + uint64_t bytesInFlightIn, + uint64_t currentCwndIn, + std::string congestionEventIn, + CongestionControlType typeIn, + void* stateIn, + std::chrono::microseconds refTimeIn) + : bytesInFlight{bytesInFlightIn}, + currentCwnd{currentCwndIn}, + congestionEvent{std::move(congestionEventIn)}, + type{typeIn}, + state{stateIn} { + eventType = QLogEventType::CongestionMetricUpdate; + refTime = refTimeIn; +} + +Document QLogBBR2CongestionMetricUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + Value value; + value.SetObject(); + value.AddMember("bytes_in_flight", + bytesInFlight, + j_allocator); + value.AddMember("current_cwnd", + currentCwnd, + j_allocator); + value.AddMember("congestion_event", + Value(congestionEvent.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("congestion_control_type", + Value(toQlogString(type).data(), j_allocator).Move(), + j_allocator); + + if (type == kBBRv2) { + Bbr2Sender::DebugState* bbr2_state = + static_cast(state); + Value tmp_object(kObjectType); + tmp_object.AddMember("mode", + Value(toQlogString(bbr2_state->mode).data(), j_allocator).Move(), + j_allocator); + tmp_object.AddMember("bandwidth_hi", + bbr2_state->bandwidth_hi.ToKBitsPerSecond(), + j_allocator); + tmp_object.AddMember("bandwidth_lo", + bbr2_state->bandwidth_lo.ToKBitsPerSecond(), + j_allocator); + tmp_object.AddMember("bandwidth_est", + bbr2_state->bandwidth_est.ToKBitsPerSecond(), + j_allocator); + tmp_object.AddMember("round_trip_counter", + bbr2_state->round_trip_count, + j_allocator); + tmp_object.AddMember("min_rtt", + bbr2_state->min_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("latest_rtt", + bbr2_state->latest_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("smoothed_rtt", + bbr2_state->smoothed_rtt.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("mean_deviation", + bbr2_state->mean_deviation.ToMicroseconds(), + j_allocator); + tmp_object.AddMember("inflight_hi", + bbr2_state->inflight_hi, + j_allocator); + tmp_object.AddMember("inflight_lo", + bbr2_state->inflight_lo, + j_allocator); + + value.AddMember("state", + tmp_object, + j_allocator); + } + + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("metric_update", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + + +QLogRequestOverStreamEvent::QLogRequestOverStreamEvent( + std::string methodIn, + QuicStreamId streamIdIn, + std::string uriIn, + std::string rangeIn, + std::chrono::microseconds refTimeIn) + : streamId{streamIdIn}, + method{std::move(methodIn)}, + uri{std::move(uriIn)}, + range{std::move(rangeIn)} { + eventType = QLogEventType::RequestOverStream; + refTime = refTimeIn; +} + +Document QLogRequestOverStreamEvent::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("stream_id", + streamId, + j_allocator); + value.AddMember("method", + Value(method.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("uri", + Value(uri.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("range", + Value(range.c_str(),j_allocator).Move(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("application", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogAppLimitedUpdateEvent::QLogAppLimitedUpdateEvent( + bool limitedIn, + std::chrono::microseconds refTimeIn) + : limited(limitedIn) { + eventType = QLogEventType::AppLimitedUpdate; + refTime = refTimeIn; +} + +Document QLogAppLimitedUpdateEvent::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("app_limited", + Value(limited ? kAppLimited : kAppUnlimited, j_allocator).Move(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("APP_LIMITED_UPDATE", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogBandwidthEstUpdateEvent::QLogBandwidthEstUpdateEvent( + uint64_t bytesIn, + std::chrono::microseconds intervalIn, + std::chrono::microseconds refTimeIn) + : bytes(bytesIn), interval(intervalIn) { + refTime = refTimeIn; + eventType = QLogEventType::BandwidthEstUpdate; +} + +Document QLogBandwidthEstUpdateEvent::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("bandwidth_bytes", + bytes, + j_allocator); + value.AddMember("bandwidth_interval", + interval.count(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("BANDIWDTH_EST_UPDATE", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPacingMetricUpdateEvent::QLogPacingMetricUpdateEvent( + uint64_t pacingBurstSizeIn, + std::chrono::microseconds pacingIntervalIn, + std::chrono::microseconds refTimeIn) + : pacingBurstSize{pacingBurstSizeIn}, pacingInterval{pacingIntervalIn} { + eventType = QLogEventType::PacingMetricUpdate; + refTime = refTimeIn; +} + +Document QLogPacingMetricUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("pacing_burst_size", + pacingBurstSize, + j_allocator); + value.AddMember("pacing_interval", + pacingInterval.count(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("metric_update", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPacingObservationEvent::QLogPacingObservationEvent( + std::string& actualIn, + std::string& expectIn, + std::string& conclusionIn, + std::chrono::microseconds refTimeIn) + : actual(actualIn), + expect(expectIn), + conclusion(conclusionIn) { + eventType = QLogEventType::PacingObservation; + refTime = refTimeIn; +} + +// TODO: Sad. I wanted moved all the string into the dynamic but this function +// is const. I think we should make all the toDynamic rvalue qualified since +// users are not supposed to use them after toJson() is called. +Document QLogPacingObservationEvent::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("actual_pacing_rate", + Value(actual.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("expect_pacing_rate", + Value(expect.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("conclusion", + Value(conclusion.c_str(), j_allocator).Move(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("metric_update", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogAppIdleUpdateEvent::QLogAppIdleUpdateEvent( + std::string& idleEventIn, + bool idleIn, + std::chrono::microseconds refTimeIn) + : idleEvent{idleEventIn}, idle{idleIn} { + eventType = QLogEventType::AppIdleUpdate; + refTime = refTimeIn; +} + +Document QLogAppIdleUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("idle_event", + Value(idleEvent.c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("idle", + idle, + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("idle_update", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPacketDropEvent::QLogPacketDropEvent( + size_t packetSizeIn, + std::string& dropReasonIn, + std::chrono::microseconds refTimeIn) + : packetSize{packetSizeIn}, dropReason{dropReasonIn} { + eventType = QLogEventType::PacketDrop; + refTime = refTimeIn; +} + +Document QLogPacketDropEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("packet_size", + packetSize, + j_allocator); + value.AddMember("drop_reason", + Value(dropReason.c_str(), j_allocator).Move(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("loss", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} // namespace quic + +QLogDatagramReceivedEvent::QLogDatagramReceivedEvent( + uint64_t dataLen, + std::chrono::microseconds refTimeIn) + : dataLen{dataLen} { + eventType = QLogEventType::DatagramReceived; + refTime = refTimeIn; +} + +Document QLogDatagramReceivedEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("data_len", + dataLen, + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogLossAlarmEvent::QLogLossAlarmEvent( + uint64_t largestSentIn, + uint64_t alarmCountIn, + uint64_t outstandingPacketsIn, + std::string& typeIn, + std::chrono::microseconds refTimeIn) + : largestSent{largestSentIn}, + alarmCount{alarmCountIn}, + outstandingPackets{outstandingPacketsIn}, + type{typeIn} { + eventType = QLogEventType::LossAlarm; + refTime = refTimeIn; +} + +Document QLogLossAlarmEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("largest_sent", + largestSent, + j_allocator); + value.AddMember("alarm_count", + alarmCount, + j_allocator); + value.AddMember("outstanding_packets", + outstandingPackets, + j_allocator); + value.AddMember("type", + Value(type.c_str(), j_allocator).Move(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("loss", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPacketLostEvent::QLogPacketLostEvent( + uint64_t lostPacketNumIn, + EncryptionLevel encryptionLevelIn, + TransmissionType transmissionTypeIn, + std::chrono::microseconds refTimeIn) + : lostPacketNum{lostPacketNumIn}, + encryptionLevel{encryptionLevelIn}, + transmissionType{transmissionTypeIn} { + eventType = QLogEventType::PacketLost; + refTime = refTimeIn; +} + +Document QLogPacketLostEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("lost_packet_num", + lostPacketNum, + j_allocator); + value.AddMember("encryption_level", + Value(toQlogString(encryptionLevel).data(), j_allocator).Move(), + j_allocator); + value.AddMember("transmission_type", + transmissionType, + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("loss", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogTransportStateUpdateEvent::QLogTransportStateUpdateEvent( + std::string& updateIn, + std::chrono::microseconds refTimeIn) + : update{updateIn} { + eventType = QLogEventType::TransportStateUpdate; + refTime = refTimeIn; +} + +Document QLogTransportStateUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("update", + Value(update.c_str(), j_allocator).Move(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPacketBufferedEvent::QLogPacketBufferedEvent( + uint64_t packetNumIn, + EncryptionLevel encryptionLevelIn, + uint64_t packetSizeIn, + std::chrono::microseconds refTimeIn) + : packetNum{packetNumIn}, + encryptionLevel{encryptionLevelIn}, + packetSize{packetSizeIn} { + eventType = QLogEventType::PacketBuffered; + refTime = refTimeIn; +} + +Document QLogPacketBufferedEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("packet_num", + packetNum, + j_allocator); + value.AddMember("encryption_level", + Value(toQlogString(encryptionLevel).data(), j_allocator).Move(), + j_allocator); + value.AddMember("packet_size", + packetSize, + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPacketAckEvent::QLogPacketAckEvent( + PacketNumberSpace packetNumSpaceIn, + uint64_t packetNumIn, + std::chrono::microseconds refTimeIn) + : packetNumSpace{packetNumSpaceIn}, packetNum{packetNumIn} { + eventType = QLogEventType::PacketAck; + refTime = refTimeIn; +} + +Document QLogPacketAckEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("packet_num_space", + Value(quiche::QuicheTextUtilsImpl::Uint64ToString(packetNumSpace).c_str(), j_allocator).Move(), + j_allocator); + value.AddMember("packet_num", + packetNum, + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogMetricUpdateEvent::QLogMetricUpdateEvent( + std::chrono::microseconds latestRttIn, + std::chrono::microseconds mrttIn, + std::chrono::microseconds srttIn, + std::chrono::microseconds ackDelayIn, + std::chrono::microseconds refTimeIn) + : latestRtt{latestRttIn}, mrtt{mrttIn}, srtt{srttIn}, ackDelay{ackDelayIn} { + eventType = QLogEventType::MetricUpdate; + refTime = refTimeIn; +} + +Document QLogMetricUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("latest_rtt", + latestRtt.count(), + j_allocator); + value.AddMember("min_rtt", + mrtt.count(), + j_allocator); + value.AddMember("smoothed_rtt", + srtt.count(), + j_allocator); + value.AddMember("ack_delay", + ackDelay.count(), + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("recovery", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogStreamStateUpdateEvent::QLogStreamStateUpdateEvent( + QuicStreamId idIn, + std::string& updateIn, + quiche::QuicheOptionalImpl timeSinceStreamCreationIn, + VantagePoint vantagePoint, + std::chrono::microseconds refTimeIn) + : id{idIn}, + update{updateIn}, + timeSinceStreamCreation(std::move(timeSinceStreamCreationIn)), + vantagePoint_(vantagePoint) { + eventType = QLogEventType::StreamStateUpdate; + refTime = refTimeIn; +} + +Document QLogStreamStateUpdateEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("id", + id, + j_allocator); + value.AddMember("update", + Value(update.c_str(), j_allocator).Move(), + j_allocator); + if (timeSinceStreamCreation) { + if (update == kOnEOM && vantagePoint_ == VantagePoint::IS_CLIENT) { + value.AddMember("ttlb", + timeSinceStreamCreation->count(), + j_allocator); + } else if (update == kOnHeaders && vantagePoint_ == VantagePoint::IS_CLIENT) { + value.AddMember("ttfb", + timeSinceStreamCreation->count(), + j_allocator); + } else { + value.AddMember("ms_since_creation", + timeSinceStreamCreation->count(), + j_allocator); + } + } + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("HTTP3", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogConnectionMigrationEvent::QLogConnectionMigrationEvent( + bool intentionalMigration, + VantagePoint vantagePoint, + std::chrono::microseconds refTimeIn) + : intentionalMigration_{intentionalMigration}, vantagePoint_(vantagePoint) { + eventType = QLogEventType::ConnectionMigration; + refTime = refTimeIn; +} + +Document QLogConnectionMigrationEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("intentional", + intentionalMigration_, + j_allocator); + if (vantagePoint_ == VantagePoint::IS_CLIENT) { + value.AddMember("type", + "initiating", + j_allocator); + } else { + value.AddMember("type", + "accepting", + j_allocator); + } + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPathValidationEvent::QLogPathValidationEvent( + bool success, + VantagePoint vantagePoint, + std::chrono::microseconds refTimeIn) + : success_{success}, vantagePoint_(vantagePoint) { + eventType = QLogEventType::PathValidation; + refTime = refTimeIn; +} + +Document QLogPathValidationEvent::toJson() const { + // creating a json array to hold the information corresponding to + // the event fields relative_time, category, event_type, trigger, data + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("success", + success_, + j_allocator); + if (vantagePoint_ == VantagePoint::IS_CLIENT) { + value.AddMember("vantagePoint", + "client", + j_allocator); + } else { + value.AddMember("vantagePoint", + "server", + j_allocator); + } + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("transport", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +QLogPriorityUpdateEvent::QLogPriorityUpdateEvent( + QuicStreamId streamId, + uint8_t urgency, + bool incremental, + std::chrono::microseconds refTimeIn) + : streamId_(streamId), urgency_(urgency), incremental_(incremental) { + eventType = QLogEventType::PriorityUpdate; + refTime = refTimeIn; +} + +Document QLogPriorityUpdateEvent::toJson() const { + Document j; + j.SetArray(); + Document::AllocatorType& j_allocator = j.GetAllocator(); + + Value value; + value.SetObject(); + value.AddMember("id", + streamId_, + j_allocator); + value.AddMember("urgency", + urgency_, + j_allocator); + value.AddMember("incremental", + "incremental_", + j_allocator); + j.PushBack(Value(quiche::QuicheTextUtilsImpl::Uint64ToString(refTime.count()).c_str(), j_allocator).Move(), j_allocator); + j.PushBack("HTTP3", j_allocator); + j.PushBack(Value(toString(eventType).data(), j_allocator).Move(), j_allocator); + j.PushBack(value, j_allocator); + return j; +} + +quiche::QuicheStringPiece toString(QLogEventType type) { + switch (type) { + case QLogEventType::PacketSent: + return "packet_sent"; + case QLogEventType::PacketReceived: + return "packet_received"; + case QLogEventType::ConnectionClose: + return "connection_close"; + case QLogEventType::TransportSummary: + return "transport_summary"; + case QLogEventType::CongestionMetricUpdate: + return "congestion_metric_update"; + case QLogEventType::PacingMetricUpdate: + return "pacing_metric_update"; + case QLogEventType::AppIdleUpdate: + return "app_idle_update"; + case QLogEventType::PacketDrop: + return "packet_drop"; + case QLogEventType::DatagramReceived: + return "datagram_received"; + case QLogEventType::LossAlarm: + return "loss_alarm"; + case QLogEventType::PacketLost: + return "packet_lost"; + case QLogEventType::TransportStateUpdate: + return "transport_state_update"; + case QLogEventType::PacketBuffered: + return "packet_buffered"; + case QLogEventType::PacketAck: + return "packet_ack"; + case QLogEventType::MetricUpdate: + return "metric_update"; + case QLogEventType::StreamStateUpdate: + return "stream_state_update"; + case QLogEventType::PacingObservation: + return "pacing_observation"; + case QLogEventType::AppLimitedUpdate: + return "app_limited_update"; + case QLogEventType::BandwidthEstUpdate: + return "bandwidth_est_update"; + case QLogEventType::ConnectionMigration: + return "connection_migration"; + case QLogEventType::PathValidation: + return "path_validation"; + case QLogEventType::PriorityUpdate: + return "priority"; + case QLogEventType::FramesProcessed: + return "frames_processed"; + case QLogEventType::RequestOverStream: + return "http_request"; + default: + return "unknown_event_type"; + } +} +} // namespace quic diff --git a/base/bvc-qlog/src/qlogger_types.h b/base/bvc-qlog/src/qlogger_types.h new file mode 100644 index 00000000..31187ab7 --- /dev/null +++ b/base/bvc-qlog/src/qlogger_types.h @@ -0,0 +1,856 @@ +#pragma once + +#include +#include +#include + +#include "base/bvc-qlog/src/qlogger_constants.h" +#include "platform/quiche_platform_impl/quiche_text_utils_impl.h" +#include "gquiche/quic/core/quic_stream.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_versions.h" +#include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/core/quic_error_codes.h" +#include "gquiche/quic/core/frames/quic_ack_frame.h" + +#include "third_party/rapidjson/include/rapidjson/document.h" +#include "third_party/rapidjson/include/rapidjson/writer.h" +#include "third_party/rapidjson/include/rapidjson/stringbuffer.h" + +namespace quic { +/** + * Application error codes are opaque to QUIC transport. Each application + * protocol can define its own error codes. + */ +using namespace rapidjson; + +class QLogFrame { + public: + QLogFrame() = default; + virtual ~QLogFrame() = default; + virtual Document toJson() const = 0; + virtual Document toShortJson() const; +}; + +class PaddingFrameLog : public QLogFrame { + public: + PaddingFrameLog() = default; + ~PaddingFrameLog() override = default; + Document toJson() const override; + Document toShortJson() const override; +}; + +class RstStreamFrameLog : public QLogFrame { + public: + QuicStreamId streamId; + QuicRstStreamErrorCode errorCode; + uint64_t offset; + + RstStreamFrameLog( + QuicStreamId streamIdIn, + QuicRstStreamErrorCode errorCodeIn, + QuicStreamOffset offsetIn) + : streamId{streamIdIn}, errorCode{errorCodeIn}, offset{offsetIn} {} + + ~RstStreamFrameLog() override = default; + Document toJson() const override; +}; + +class ConnectionCloseFrameLog : public QLogFrame { + public: + QuicConnectionCloseType closeType; + uint64_t wireErrorCode; + QuicErrorCode quicErrorCode; + std::string errorDetails; + uint64_t transportCloseFrameType; + + ConnectionCloseFrameLog( + QuicConnectionCloseType closeTypeIn, + uint64_t wireErrorCodeIn, + QuicErrorCode quicErrorCodeIn, + std::string errorDetailsIn, + uint64_t transportCloseFrameTypeIn) + : closeType{closeTypeIn}, + wireErrorCode{wireErrorCodeIn}, + quicErrorCode{quicErrorCodeIn}, + errorDetails(errorDetailsIn), + transportCloseFrameType{transportCloseFrameTypeIn} {} + + ~ConnectionCloseFrameLog() override = default; + Document toJson() const override; +}; + +class GoAwayFrameLog : public QLogFrame { + public: + QuicErrorCode errorCode; + QuicStreamId lastGoodStreamId; + std::string reasonPhrase; + + GoAwayFrameLog( + QuicErrorCode errorCodeIn, + QuicStreamId lastGoodStreamIdIn, + std::string reasonPhraseIn) + : errorCode{errorCodeIn}, + lastGoodStreamId{lastGoodStreamIdIn}, + reasonPhrase{reasonPhraseIn} {} + + ~GoAwayFrameLog() override = default; + Document toJson() const override; +}; + +class WindowUpdateFrameLog : public QLogFrame { + public: + QuicStreamId streamId; + uint64_t maxData; + + WindowUpdateFrameLog( + QuicStreamId streamIdIn, + uint64_t maxDataIn) + : streamId{streamIdIn}, + maxData{maxDataIn} {} + + ~WindowUpdateFrameLog() override = default; + Document toJson() const override; +}; + +class BlockedFrameLog : public QLogFrame { + public: + QuicStreamId streamId; + + BlockedFrameLog(QuicStreamId streamIdIn) + : streamId{streamIdIn} {} + + ~BlockedFrameLog() override = default; + Document toJson() const override; +}; + +class StopWaitingFrameLog : public QLogFrame { + public: + StopWaitingFrameLog() = default; + ~StopWaitingFrameLog() override = default; + Document toJson() const override; +}; + +class PingFrameLog : public QLogFrame { + public: + PingFrameLog() = default; + ~PingFrameLog() override = default; + Document toJson() const override; +}; + +class AckFrameLog : public QLogFrame { + public: + PacketNumberQueue packetNumberQueue; + std::chrono::microseconds ackDelay; + + AckFrameLog( + const PacketNumberQueue packetNumberQueueIn, + uint64_t ackDelayIn) + : packetNumberQueue{packetNumberQueueIn}, ackDelay{ackDelayIn} {} + ~AckFrameLog() override = default; + Document toJson() const override; + Document toShortJson() const override; +}; + +class StreamFrameLog : public QLogFrame { + public: + QuicStreamId streamId; + uint64_t offset; + uint64_t len; + bool fin; + + StreamFrameLog( + QuicStreamId streamIdIn, + uint64_t offsetIn, + uint64_t lenIn, + bool finIn) + : streamId{streamIdIn}, offset{offsetIn}, len{lenIn}, fin{finIn} {} + ~StreamFrameLog() override = default; + + Document toJson() const override; + Document toShortJson() const; +}; + +class CryptoFrameLog : public QLogFrame { + public: + EncryptionLevel level; + uint64_t offset; + uint64_t dataLength; + + CryptoFrameLog( + EncryptionLevel levelIn, + uint64_t offsetIn, + uint64_t dataLengthIn) + : level{levelIn}, offset{offsetIn}, dataLength{dataLengthIn} {} + ~CryptoFrameLog() override = default; + Document toJson() const override; +}; + +class HandshakeDoneFrameLog : public QLogFrame { + public: + HandshakeDoneFrameLog() = default; + ~HandshakeDoneFrameLog() override = default; + Document toJson() const override; +}; + +class MTUDiscoveryFrameLog : public QLogFrame { + public: + MTUDiscoveryFrameLog() = default; + ~MTUDiscoveryFrameLog() override = default; + Document toJson() const override; +}; + +class NewConnectionIdFrameLog : public QLogFrame { + public: + std::string newConnectionId; + uint64_t sequenceNumber; + + NewConnectionIdFrameLog( + std::string newConnectionIdIn, + uint64_t sequenceNumberIn) + : newConnectionId{newConnectionIdIn}, sequenceNumber{sequenceNumberIn} {} + + ~NewConnectionIdFrameLog() override = default; + Document toJson() const override; +}; + +class MaxStreamsFrameLog : public QLogFrame { + public: + uint64_t streamCount; + bool unidirectional; + + MaxStreamsFrameLog( + uint64_t streamCountIn, + bool unidirectionalIn) + : streamCount{streamCountIn}, unidirectional{unidirectionalIn} {} + + ~MaxStreamsFrameLog() override = default; + Document toJson() const override; +}; + +class StreamsBlockedFrameLog : public QLogFrame { + public: + uint64_t streamCount; + bool unidirectional; + + StreamsBlockedFrameLog( + uint64_t streamCountIn, + bool unidirectionalIn) + : streamCount{streamCountIn}, unidirectional{unidirectionalIn} {} + + ~StreamsBlockedFrameLog() override = default; + Document toJson() const override; +}; + +class PathResponseFrameLog : public QLogFrame { + public: + std::string pathData; + + explicit PathResponseFrameLog(std::string pathDataIn) : pathData{pathDataIn} {} + ~PathResponseFrameLog() override = default; + Document toJson() const override; +}; + +class PathChallengeFrameLog : public QLogFrame { + public: + std::string pathData; + + explicit PathChallengeFrameLog(std::string pathDataIn) : pathData{pathDataIn} {} + ~PathChallengeFrameLog() override = default; + Document toJson() const override; +}; + +class StopSendingFrameLog : public QLogFrame { + public: + QuicStreamId streamId; + QuicRstStreamErrorCode errorCode; + + StopSendingFrameLog(QuicStreamId streamIdIn, QuicRstStreamErrorCode errorCodeIn) + : streamId{streamIdIn}, errorCode{errorCodeIn} {} + ~StopSendingFrameLog() override = default; + Document toJson() const override; +}; + +class MessageFrameLog : public QLogFrame { + public: + uint32_t messageId; + uint64_t length; + + MessageFrameLog(uint32_t messageIdIn, uint64_t lengthIn) + : messageId{messageIdIn}, length{lengthIn} {} + ~MessageFrameLog() override = default; + Document toJson() const override; +}; + +class NewTokenFrameLog : public QLogFrame { + public: + NewTokenFrameLog() = default; + ~NewTokenFrameLog() override = default; + Document toJson() const override; +}; + +class RetireConnectionIdFrameLog : public QLogFrame { + public: + uint64_t sequenceNumber; + + RetireConnectionIdFrameLog(uint64_t sequenceNumberIn) + : sequenceNumber(sequenceNumberIn) {} + + ~RetireConnectionIdFrameLog() override = default; + Document toJson() const override; +}; + +class AckFrequencyFrameLog : public QLogFrame { + public: + uint64_t sequenceNumber; + uint64_t packetTolerance; + uint64_t updateMaxAckDelay; + bool ignoreOrder; + + explicit AckFrequencyFrameLog( + uint64_t sequenceNumberIn, + uint64_t packetToleranceIn, + uint64_t updateMaxAckDelayIn, + bool ignoreOrderIn) + : sequenceNumber(sequenceNumberIn), + packetTolerance(packetToleranceIn), + updateMaxAckDelay(updateMaxAckDelayIn), + ignoreOrder(ignoreOrderIn) {} + ~AckFrequencyFrameLog() override = default; + Document toJson() const override; +}; + +class VersionNegotiationLog { + public: + std::vector versions; + + explicit VersionNegotiationLog(const std::vector& versionsIn) + : versions{versionsIn} {} + ~VersionNegotiationLog() = default; + Document toJson() const; +}; + +enum class QLogEventType : uint32_t { + PacketReceived, + PacketSent, + ConnectionClose, + TransportSummary, + CongestionMetricUpdate, + PacingMetricUpdate, + AppIdleUpdate, + PacketDrop, + DatagramReceived, + LossAlarm, + PacketLost, + TransportStateUpdate, + PacketBuffered, + PacketAck, + MetricUpdate, + StreamStateUpdate, + PacingObservation, + AppLimitedUpdate, + BandwidthEstUpdate, + ConnectionMigration, + PathValidation, + PriorityUpdate, + FramesProcessed, + RequestOverStream +}; + +quiche::QuicheStringPiece toString(QLogEventType type); + +class QLogEvent { + public: + QLogEvent() = default; + virtual ~QLogEvent() = default; + virtual Document toJson() const = 0; + std::chrono::microseconds refTime; + QLogEventType eventType; +}; + +class QLogFramesProcessed : public QLogEvent { + public: + QLogFramesProcessed() = default; + ~QLogFramesProcessed() override = default; + quiche::QuicheStringPiece weaver; + QuicFrameType framesType; + std::vector> frames; + std::vector packetSizes; + std::vector packetNums; + std::vector timeDrifts; + std::string packetType; + Document toJson() const override; +}; + +class QLogPacketEvent : public QLogEvent { + public: + QLogPacketEvent() = default; + ~QLogPacketEvent() override = default; + std::vector> frames; + std::string packetType; + std::string transmissionType; + uint64_t packetNum{0}; + uint64_t packetSize{0}; + Document toJson() const override; +}; + +class QLogVersionNegotiationEvent : public QLogEvent { + public: + QLogVersionNegotiationEvent() = default; + ~QLogVersionNegotiationEvent() override = default; + std::unique_ptr versionLog; + std::string packetType; + uint64_t packetSize{0}; + + Document toJson() const override; +}; + +class QLogRetryEvent : public QLogEvent { + public: + QLogRetryEvent() = default; + ~QLogRetryEvent() override = default; + + std::string packetType; + uint64_t packetSize{0}; + uint64_t tokenSize{0}; + + Document toJson() const override; +}; + +class QLogConnectionCloseEvent : public QLogEvent { + public: + QLogConnectionCloseEvent( + QuicErrorCode errorIn, + std::string reasonIn, + ConnectionCloseSource sourceIn, + std::chrono::microseconds refTimeIn); + ~QLogConnectionCloseEvent() override = default; + QuicErrorCode error; + std::string reason; + ConnectionCloseSource source; + + Document toJson() const override; +}; + +struct TransportSummaryArgs { + uint64_t totalBytesSent{}; + uint64_t totalBytesRecvd{}; + uint64_t sumCurWriteOffset{}; + uint64_t sumMaxObservedOffset{}; + uint64_t sumCurStreamBufferLen{}; + uint64_t totalPacketsLost{}; + uint64_t totalStartupDuration{}; + uint64_t totalDrainDuration{}; + uint64_t totalProbeBWDuration{}; + uint64_t totalProbeRttDuration{}; + uint64_t totalNotRecoveryDuration{}; + uint64_t totalGrowthDuration{}; + uint64_t totalConservationDuration{}; + uint64_t totalStreamBytesCloned{}; + uint64_t totalBytesCloned{}; + uint64_t totalCryptoDataWritten{}; + uint64_t totalCryptoDataRecvd{}; + uint64_t currentWritableBytes{}; + uint64_t currentConnFlowControl{}; + bool usedZeroRtt{}; + double smoothedMinRtt{}; + double smoothedMaxBandwidth{}; + float startupDurationRatio{}; + float drainDurationRatio{}; + float probebwDurationRatio{}; + float proberttDurationRatio{}; + float NotRecoveryDurationRatio{}; + float GrowthDurationRatio{}; + float ConservationDurationRatio{}; + float AverageDifference{}; +}; + +class QLogTransportSummaryEvent : public QLogEvent { + public: + QLogTransportSummaryEvent( + uint64_t totalBytesSent, + uint64_t totalPacketsSent, + uint64_t totalBytesRecvd, + uint64_t totalPacketsRecvd, + uint64_t sumCurWriteOffset, + uint64_t sumMaxObservedOffset, + uint64_t sumCurStreamBufferLen, + uint64_t totalPacketsLost, + uint64_t totalStartupDuration, + uint64_t totalDrainDuration, + uint64_t totalProbeBWDuration, + uint64_t totalProbeRttDuration, + uint64_t totalNotRecoveryDuration, + uint64_t totalGrowthDuration, + uint64_t totalConservationDuration, + uint64_t totalStreamBytesCloned, + uint64_t totalBytesCloned, + uint64_t totalCryptoDataWritten, + uint64_t totalCryptoDataRecvd, + uint64_t currentWritableBytes, + uint64_t currentConnFlowControl, + bool usedZeroRtt, + QuicTransportVersion version, + CongestionControlType congestionType, + double smoothedMinRtt, + double smoothedMaxBandwidth, + float startupDurationRatio, + float drainDurationRatio, + float probebwDurationRatio, + float proberttDurationRatio, + float NotRecoveryDurationRatio, + float GrowthDurationRatio, + float ConservationDurationRatio, + float AverageDifference, + std::chrono::microseconds refTime); + ~QLogTransportSummaryEvent() override = default; + uint64_t totalBytesSent; + uint64_t totalPacketsSent; + uint64_t totalBytesRecvd; + uint64_t totalPacketsRecvd; + uint64_t sumCurWriteOffset; + uint64_t sumMaxObservedOffset; + uint64_t sumCurStreamBufferLen; + uint64_t totalPacketsLost; + uint64_t totalStartupDuration; + uint64_t totalDrainDuration; + uint64_t totalProbeBWDuration; + uint64_t totalProbeRttDuration; + uint64_t totalNotRecoveryDuration; + uint64_t totalGrowthDuration; + uint64_t totalConservationDuration; + uint64_t totalStreamBytesCloned; + uint64_t totalBytesCloned; + uint64_t totalCryptoDataWritten; + uint64_t totalCryptoDataRecvd; + uint64_t currentWritableBytes; + uint64_t currentConnFlowControl; + bool usedZeroRtt; + QuicTransportVersion quicVersion; + CongestionControlType congestionType; + double smoothedMinRtt; + double smoothedMaxBandwidth; + float startupDurationRatio; + float drainDurationRatio; + float probebwDurationRatio; + float proberttDurationRatio; + float NotRecoveryDurationRatio; + float GrowthDurationRatio; + float ConservationDurationRatio; + float AverageDifference; + Document toJson() const override; +}; + +class QLogBBRCongestionMetricUpdateEvent : public QLogEvent { + public: + QLogBBRCongestionMetricUpdateEvent( + uint64_t bytesInFlightIn, + uint64_t currentCwndIn, + std::string congestionEventIn, + quic::CongestionControlType typeIn, + void* stateIn, + std::chrono::microseconds refTimeIn); + ~QLogBBRCongestionMetricUpdateEvent() override = default; + uint64_t bytesInFlight; + uint64_t currentCwnd; + std::string congestionEvent; + quic::CongestionControlType type; + void* state; + + Document toJson() const override; +}; + +class QLogCubicCongestionMetricUpdateEvent : public QLogEvent { + public: + QLogCubicCongestionMetricUpdateEvent( + uint64_t bytesInFlightIn, + uint64_t currentCwndIn, + std::string congestionEventIn, + quic::CongestionControlType typeIn, + void* stateIn, + std::chrono::microseconds refTimeIn); + ~QLogCubicCongestionMetricUpdateEvent() override = default; + uint64_t bytesInFlight; + uint64_t currentCwnd; + std::string congestionEvent; + quic::CongestionControlType type; + void* state; + Document toJson() const override; +}; + +class QLogBBR2CongestionMetricUpdateEvent : public QLogEvent { + public: + QLogBBR2CongestionMetricUpdateEvent( + uint64_t bytesInFlightIn, + uint64_t currentCwndIn, + std::string congestionEventIn, + quic::CongestionControlType typeIn, + void* stateIn, + std::chrono::microseconds refTimeIn); + ~QLogBBR2CongestionMetricUpdateEvent() override = default; + uint64_t bytesInFlight; + uint64_t currentCwnd; + std::string congestionEvent; + quic::CongestionControlType type; + void* state; + Document toJson() const override; +}; + +class QLogAppLimitedUpdateEvent : public QLogEvent { + public: + explicit QLogAppLimitedUpdateEvent( + bool limitedIn, + std::chrono::microseconds refTimeIn); + ~QLogAppLimitedUpdateEvent() override = default; + + Document toJson() const override; + + bool limited; +}; + +class QLogBandwidthEstUpdateEvent : public QLogEvent { + public: + explicit QLogBandwidthEstUpdateEvent( + uint64_t bytes, + std::chrono::microseconds interval, + std::chrono::microseconds refTimeIn); + ~QLogBandwidthEstUpdateEvent() override = default; + + Document toJson() const override; + + uint64_t bytes; + std::chrono::microseconds interval; +}; + +class QLogPacingMetricUpdateEvent : public QLogEvent { + public: + QLogPacingMetricUpdateEvent( + uint64_t pacingBurstSize, + std::chrono::microseconds pacingInterval, + std::chrono::microseconds refTime); + ~QLogPacingMetricUpdateEvent() override = default; + uint64_t pacingBurstSize; + std::chrono::microseconds pacingInterval; + + Document toJson() const override; +}; + +class QLogPacingObservationEvent : public QLogEvent { + public: + QLogPacingObservationEvent( + std::string& actualIn, + std::string& expectIn, + std::string& conclusionIn, + std::chrono::microseconds refTimeIn); + std::string actual; + std::string expect; + std::string conclusion; + + ~QLogPacingObservationEvent() override = default; + Document toJson() const override; +}; + +class QLogAppIdleUpdateEvent : public QLogEvent { + public: + QLogAppIdleUpdateEvent( + std::string& idleEvent, + bool idle, + std::chrono::microseconds refTime); + ~QLogAppIdleUpdateEvent() override = default; + std::string idleEvent; + bool idle; + + Document toJson() const override; +}; + +class QLogPacketDropEvent : public QLogEvent { + public: + QLogPacketDropEvent( + size_t packetSize, + std::string& dropReason, + std::chrono::microseconds refTime); + ~QLogPacketDropEvent() override = default; + size_t packetSize; + std::string dropReason; + + Document toJson() const override; +}; + +class QLogDatagramReceivedEvent : public QLogEvent { + public: + QLogDatagramReceivedEvent( + uint64_t dataLen, + std::chrono::microseconds refTime); + ~QLogDatagramReceivedEvent() override = default; + uint64_t dataLen; + + Document toJson() const override; +}; + +class QLogLossAlarmEvent : public QLogEvent { + public: + QLogLossAlarmEvent( + uint64_t largestSent, + uint64_t alarmCount, + uint64_t outstandingPackets, + std::string& type, + std::chrono::microseconds refTime); + ~QLogLossAlarmEvent() override = default; + uint64_t largestSent; + uint64_t alarmCount; + uint64_t outstandingPackets; + std::string type; + Document toJson() const override; +}; + +class QLogPacketLostEvent : public QLogEvent { + public: + QLogPacketLostEvent( + uint64_t LostPacketNum, + EncryptionLevel level, + TransmissionType type, + std::chrono::microseconds refTime); + ~QLogPacketLostEvent() override = default; + uint64_t lostPacketNum; + EncryptionLevel encryptionLevel; + TransmissionType transmissionType; + Document toJson() const override; +}; + +class QLogTransportStateUpdateEvent : public QLogEvent { + public: + QLogTransportStateUpdateEvent( + std::string& update, + std::chrono::microseconds refTime); + ~QLogTransportStateUpdateEvent() override = default; + std::string update; + Document toJson() const override; +}; + +class QLogPacketBufferedEvent : public QLogEvent { + public: + QLogPacketBufferedEvent( + uint64_t packetNum, + EncryptionLevel encryptionLevel, + uint64_t packetSize, + std::chrono::microseconds refTime); + ~QLogPacketBufferedEvent() override = default; + uint64_t packetNum; + EncryptionLevel encryptionLevel; + uint64_t packetSize; + Document toJson() const override; +}; + +class QLogPacketAckEvent : public QLogEvent { + public: + QLogPacketAckEvent( + PacketNumberSpace packetNumSpace, + uint64_t packetNum, + std::chrono::microseconds refTime); + ~QLogPacketAckEvent() override = default; + PacketNumberSpace packetNumSpace; + uint64_t packetNum; + Document toJson() const override; +}; + +class QLogMetricUpdateEvent : public QLogEvent { + public: + QLogMetricUpdateEvent( + std::chrono::microseconds latestRtt, + std::chrono::microseconds mrtt, + std::chrono::microseconds srtt, + std::chrono::microseconds ackDelay, + std::chrono::microseconds refTime); + ~QLogMetricUpdateEvent() override = default; + std::chrono::microseconds latestRtt; + std::chrono::microseconds mrtt; + std::chrono::microseconds srtt; + std::chrono::microseconds ackDelay; + Document toJson() const override; +}; + +class QLogStreamStateUpdateEvent : public QLogEvent { + public: + QLogStreamStateUpdateEvent( + QuicStreamId id, + std::string& update, + quiche::QuicheOptionalImpl timeSinceStreamCreation, + VantagePoint vantagePoint, + std::chrono::microseconds refTime); + ~QLogStreamStateUpdateEvent() override = default; + QuicStreamId id; + std::string update; + quiche::QuicheOptionalImpl timeSinceStreamCreation; + Document toJson() const override; + + private: + VantagePoint vantagePoint_; +}; + +class QLogConnectionMigrationEvent : public QLogEvent { + public: + QLogConnectionMigrationEvent( + bool intentionalMigration, + VantagePoint vantagePoint, + std::chrono::microseconds refTime); + + ~QLogConnectionMigrationEvent() override = default; + + Document toJson() const override; + + bool intentionalMigration_; + VantagePoint vantagePoint_; +}; + +class QLogPathValidationEvent : public QLogEvent { + public: + // The VantagePoint represents who initiates the path validation (sends out + // Path Challenge). + QLogPathValidationEvent( + bool success, + VantagePoint vantagePoint, + std::chrono::microseconds refTime); + + ~QLogPathValidationEvent() override = default; + + Document toJson() const override; + bool success_; + VantagePoint vantagePoint_; +}; + +class QLogPriorityUpdateEvent : public QLogEvent { + public: + explicit QLogPriorityUpdateEvent( + QuicStreamId id, + uint8_t urgency, + bool incremental, + std::chrono::microseconds refTimeIn); + ~QLogPriorityUpdateEvent() override = default; + + Document toJson() const override; + + private: + QuicStreamId streamId_; + uint8_t urgency_; + bool incremental_; +}; + +class QLogRequestOverStreamEvent : public QLogEvent { + public: + QLogRequestOverStreamEvent( + std::string methodIn, + QuicStreamId streamIdIn, + std::string uriIn, + std::string rangeIn, + std::chrono::microseconds refTimeIn); + ~QLogRequestOverStreamEvent() override = default; + + QuicStreamId streamId; + std::string method; + std::string uri; + std::string range; + + Document toJson() const override; +}; + +} // namespace quic diff --git a/base/files/file_path.cc b/base/files/file_path.cc index 4246d622..6d60148c 100644 --- a/base/files/file_path.cc +++ b/base/files/file_path.cc @@ -1,232 +1,232 @@ -#include "base/files/file_path.h" - -namespace base { -using StringType = FilePath::StringType; -using StringPieceType = FilePath::StringPieceType; - -namespace { - -const FilePath::CharType kStringTerminator = FILE_PATH_LITERAL('\0'); -// If this FilePath contains a drive letter specification, returns the -// position of the last character of the drive letter specification, -// otherwise returns npos. This can only be true on Windows, when a pathname -// begins with a letter followed by a colon. On other platforms, this always -// returns npos. -StringPieceType::size_type FindDriveLetter(StringPieceType path) { -#if defined(FILE_PATH_USES_DRIVE_LETTERS) - // This is dependent on an ASCII-based character set, but that's a - // reasonable assumption. iswalpha can be too inclusive here. - if (path.length() >= 2 && path[1] == L':' && - ((path[0] >= L'A' && path[0] <= L'Z') || - (path[0] >= L'a' && path[0] <= L'z'))) { - return 1; - } -#endif // FILE_PATH_USES_DRIVE_LETTERS - return StringType::npos; -} - -bool AreAllSeparators(const StringType& input) { - for (auto it : input) { - if (!FilePath::IsSeparator(it)) - return false; - } - - return true; -} -} // namespace - -FilePath::FilePath() = default; - -FilePath::FilePath(const FilePath& that) = default; -FilePath::FilePath(FilePath&& that) noexcept = default; - -FilePath::FilePath(StringPieceType path) : path_(path) { - StringType::size_type nul_pos = path_.find(kStringTerminator); - if (nul_pos != StringType::npos) - path_.erase(nul_pos, StringType::npos); -} - -FilePath::~FilePath() = default; - -FilePath& FilePath::operator=(const FilePath& that) = default; - -FilePath& FilePath::operator=(FilePath&& that) noexcept = default; - -bool FilePath::operator==(const FilePath& that) const { -#if defined(FILE_PATH_USES_DRIVE_LETTERS) - return EqualDriveLetterCaseInsensitive(this->path_, that.path_); -#else // defined(FILE_PATH_USES_DRIVE_LETTERS) - return path_ == that.path_; -#endif // defined(FILE_PATH_USES_DRIVE_LETTERS) -} - -bool FilePath::operator!=(const FilePath& that) const { -#if defined(FILE_PATH_USES_DRIVE_LETTERS) - return !EqualDriveLetterCaseInsensitive(this->path_, that.path_); -#else // defined(FILE_PATH_USES_DRIVE_LETTERS) - return path_ != that.path_; -#endif // defined(FILE_PATH_USES_DRIVE_LETTERS) -} - -std::ostream& operator<<(std::ostream& out, const FilePath& file_path) { - return out << file_path.value(); -} - -// libgen's dirname and basename aren't guaranteed to be thread-safe and aren't -// guaranteed to not modify their input strings, and in fact are implemented -// differently in this regard on different platforms. Don't use them, but -// adhere to their behavior. -FilePath FilePath::DirName() const { - FilePath new_path(path_); - new_path.StripTrailingSeparatorsInternal(); - - // The drive letter, if any, always needs to remain in the output. If there - // is no drive letter, as will always be the case on platforms which do not - // support drive letters, letter will be npos, or -1, so the comparisons and - // resizes below using letter will still be valid. - StringType::size_type letter = FindDriveLetter(new_path.path_); - - StringType::size_type last_separator = - new_path.path_.find_last_of(kSeparators, StringType::npos, - kSeparatorsLength - 1); - if (last_separator == StringType::npos) { - // path_ is in the current directory. - new_path.path_.resize(letter + 1); - } else if (last_separator == letter + 1) { - // path_ is in the root directory. - new_path.path_.resize(letter + 2); - } else if (last_separator == letter + 2 && - IsSeparator(new_path.path_[letter + 1])) { - // path_ is in "//" (possibly with a drive letter); leave the double - // separator intact indicating alternate root. - new_path.path_.resize(letter + 3); - } else if (last_separator != 0) { - // path_ is somewhere else, trim the basename. - new_path.path_.resize(last_separator); - } - - new_path.StripTrailingSeparatorsInternal(); - if (!new_path.path_.length()) - new_path.path_ = kCurrentDirectory; - - return new_path; -} - -FilePath FilePath::BaseName() const { - FilePath new_path(path_); - new_path.StripTrailingSeparatorsInternal(); - - // The drive letter, if any, is always stripped. - StringType::size_type letter = FindDriveLetter(new_path.path_); - if (letter != StringType::npos) { - new_path.path_.erase(0, letter + 1); - } - - // Keep everything after the final separator, but if the pathname is only - // one character and it's a separator, leave it alone. - StringType::size_type last_separator = - new_path.path_.find_last_of(kSeparators, StringType::npos, - kSeparatorsLength - 1); - if (last_separator != StringType::npos && - last_separator < new_path.path_.length() - 1) { - new_path.path_.erase(0, last_separator + 1); - } - - return new_path; -} - -// static -bool FilePath::IsSeparator(CharType character) { - for (size_t i = 0; i < kSeparatorsLength - 1; ++i) { - if (character == kSeparators[i]) { - return true; - } - } - - return false; -} - -void FilePath::GetComponents(std::vector* components) const { - QUICHE_DCHECK(components); - if (!components) - return; - components->clear(); - if (value().empty()) - return; - - std::vector ret_val; - FilePath current = *this; - FilePath base; - - // Capture path components. - while (current != current.DirName()) { - base = current.BaseName(); - if (!AreAllSeparators(base.value())) - ret_val.push_back(base.value()); - current = current.DirName(); - } - - // Capture root, if any. - base = current.BaseName(); - if (!base.value().empty() && base.value() != kCurrentDirectory) - ret_val.push_back(current.BaseName().value()); - - // Capture drive letter, if any. - FilePath dir = current.DirName(); - StringType::size_type letter = FindDriveLetter(dir.value()); - if (letter != StringType::npos) { - ret_val.push_back(StringType(dir.value(), 0, letter + 1)); - } - - *components = std::vector(ret_val.rbegin(), ret_val.rend()); -} - -bool FilePath::ReferencesParent() const { - if (path_.find(kParentDirectory) == StringType::npos) { - // GetComponents is quite expensive, so avoid calling it in the majority - // of cases where there isn't a kParentDirectory anywhere in the path. - return false; - } - - std::vector components; - GetComponents(&components); - - std::vector::const_iterator it = components.begin(); - for (; it != components.end(); ++it) { - const StringType& component = *it; - // Windows has odd, undocumented behavior with path components containing - // only whitespace and . characters. So, if all we see is . and - // whitespace, then we treat any .. sequence as referencing parent. - // For simplicity we enforce this on all platforms. - if (component.find_first_not_of(FILE_PATH_LITERAL(". \n\r\t")) == - std::string::npos && - component.find(kParentDirectory) != std::string::npos) { - return true; - } - } - return false; -} - - -void FilePath::StripTrailingSeparatorsInternal() { - // If there is no drive letter, start will be 1, which will prevent stripping - // the leading separator if there is only one separator. If there is a drive - // letter, start will be set appropriately to prevent stripping the first - // separator following the drive letter, if a separator immediately follows - // the drive letter. - StringType::size_type start = FindDriveLetter(path_) + 2; - - StringType::size_type last_stripped = StringType::npos; - for (StringType::size_type pos = path_.length(); - pos > start && IsSeparator(path_[pos - 1]); - --pos) { - // If the string only has two separators and they're at the beginning, - // don't strip them, unless the string began with more than two separators. - if (pos != start + 1 || last_stripped == start + 2 || - !IsSeparator(path_[start - 1])) { - path_.resize(pos - 1); - last_stripped = pos; - } - } -} -} // namespace base +#include "base/files/file_path.h" + +namespace base { +using StringType = FilePath::StringType; +using StringPieceType = FilePath::StringPieceType; + +namespace { + +const FilePath::CharType kStringTerminator = FILE_PATH_LITERAL('\0'); +// If this FilePath contains a drive letter specification, returns the +// position of the last character of the drive letter specification, +// otherwise returns npos. This can only be true on Windows, when a pathname +// begins with a letter followed by a colon. On other platforms, this always +// returns npos. +StringPieceType::size_type FindDriveLetter(StringPieceType path) { +#if defined(FILE_PATH_USES_DRIVE_LETTERS) + // This is dependent on an ASCII-based character set, but that's a + // reasonable assumption. iswalpha can be too inclusive here. + if (path.length() >= 2 && path[1] == L':' && + ((path[0] >= L'A' && path[0] <= L'Z') || + (path[0] >= L'a' && path[0] <= L'z'))) { + return 1; + } +#endif // FILE_PATH_USES_DRIVE_LETTERS + return StringType::npos; +} + +bool AreAllSeparators(const StringType& input) { + for (auto it : input) { + if (!FilePath::IsSeparator(it)) + return false; + } + + return true; +} +} // namespace + +FilePath::FilePath() = default; + +FilePath::FilePath(const FilePath& that) = default; +FilePath::FilePath(FilePath&& that) noexcept = default; + +FilePath::FilePath(StringPieceType path) : path_(path) { + StringType::size_type nul_pos = path_.find(kStringTerminator); + if (nul_pos != StringType::npos) + path_.erase(nul_pos, StringType::npos); +} + +FilePath::~FilePath() = default; + +FilePath& FilePath::operator=(const FilePath& that) = default; + +FilePath& FilePath::operator=(FilePath&& that) noexcept = default; + +bool FilePath::operator==(const FilePath& that) const { +#if defined(FILE_PATH_USES_DRIVE_LETTERS) + return EqualDriveLetterCaseInsensitive(this->path_, that.path_); +#else // defined(FILE_PATH_USES_DRIVE_LETTERS) + return path_ == that.path_; +#endif // defined(FILE_PATH_USES_DRIVE_LETTERS) +} + +bool FilePath::operator!=(const FilePath& that) const { +#if defined(FILE_PATH_USES_DRIVE_LETTERS) + return !EqualDriveLetterCaseInsensitive(this->path_, that.path_); +#else // defined(FILE_PATH_USES_DRIVE_LETTERS) + return path_ != that.path_; +#endif // defined(FILE_PATH_USES_DRIVE_LETTERS) +} + +std::ostream& operator<<(std::ostream& out, const FilePath& file_path) { + return out << file_path.value(); +} + +// libgen's dirname and basename aren't guaranteed to be thread-safe and aren't +// guaranteed to not modify their input strings, and in fact are implemented +// differently in this regard on different platforms. Don't use them, but +// adhere to their behavior. +FilePath FilePath::DirName() const { + FilePath new_path(path_); + new_path.StripTrailingSeparatorsInternal(); + + // The drive letter, if any, always needs to remain in the output. If there + // is no drive letter, as will always be the case on platforms which do not + // support drive letters, letter will be npos, or -1, so the comparisons and + // resizes below using letter will still be valid. + StringType::size_type letter = FindDriveLetter(new_path.path_); + + StringType::size_type last_separator = + new_path.path_.find_last_of(kSeparators, StringType::npos, + kSeparatorsLength - 1); + if (last_separator == StringType::npos) { + // path_ is in the current directory. + new_path.path_.resize(letter + 1); + } else if (last_separator == letter + 1) { + // path_ is in the root directory. + new_path.path_.resize(letter + 2); + } else if (last_separator == letter + 2 && + IsSeparator(new_path.path_[letter + 1])) { + // path_ is in "//" (possibly with a drive letter); leave the double + // separator intact indicating alternate root. + new_path.path_.resize(letter + 3); + } else if (last_separator != 0) { + // path_ is somewhere else, trim the basename. + new_path.path_.resize(last_separator); + } + + new_path.StripTrailingSeparatorsInternal(); + if (!new_path.path_.length()) + new_path.path_ = kCurrentDirectory; + + return new_path; +} + +FilePath FilePath::BaseName() const { + FilePath new_path(path_); + new_path.StripTrailingSeparatorsInternal(); + + // The drive letter, if any, is always stripped. + StringType::size_type letter = FindDriveLetter(new_path.path_); + if (letter != StringType::npos) { + new_path.path_.erase(0, letter + 1); + } + + // Keep everything after the final separator, but if the pathname is only + // one character and it's a separator, leave it alone. + StringType::size_type last_separator = + new_path.path_.find_last_of(kSeparators, StringType::npos, + kSeparatorsLength - 1); + if (last_separator != StringType::npos && + last_separator < new_path.path_.length() - 1) { + new_path.path_.erase(0, last_separator + 1); + } + + return new_path; +} + +// static +bool FilePath::IsSeparator(CharType character) { + for (size_t i = 0; i < kSeparatorsLength - 1; ++i) { + if (character == kSeparators[i]) { + return true; + } + } + + return false; +} + +void FilePath::GetComponents(std::vector* components) const { + QUICHE_DCHECK(components); + if (!components) + return; + components->clear(); + if (value().empty()) + return; + + std::vector ret_val; + FilePath current = *this; + FilePath base; + + // Capture path components. + while (current != current.DirName()) { + base = current.BaseName(); + if (!AreAllSeparators(base.value())) + ret_val.push_back(base.value()); + current = current.DirName(); + } + + // Capture root, if any. + base = current.BaseName(); + if (!base.value().empty() && base.value() != kCurrentDirectory) + ret_val.push_back(current.BaseName().value()); + + // Capture drive letter, if any. + FilePath dir = current.DirName(); + StringType::size_type letter = FindDriveLetter(dir.value()); + if (letter != StringType::npos) { + ret_val.push_back(StringType(dir.value(), 0, letter + 1)); + } + + *components = std::vector(ret_val.rbegin(), ret_val.rend()); +} + +bool FilePath::ReferencesParent() const { + if (path_.find(kParentDirectory) == StringType::npos) { + // GetComponents is quite expensive, so avoid calling it in the majority + // of cases where there isn't a kParentDirectory anywhere in the path. + return false; + } + + std::vector components; + GetComponents(&components); + + std::vector::const_iterator it = components.begin(); + for (; it != components.end(); ++it) { + const StringType& component = *it; + // Windows has odd, undocumented behavior with path components containing + // only whitespace and . characters. So, if all we see is . and + // whitespace, then we treat any .. sequence as referencing parent. + // For simplicity we enforce this on all platforms. + if (component.find_first_not_of(FILE_PATH_LITERAL(". \n\r\t")) == + std::string::npos && + component.find(kParentDirectory) != std::string::npos) { + return true; + } + } + return false; +} + + +void FilePath::StripTrailingSeparatorsInternal() { + // If there is no drive letter, start will be 1, which will prevent stripping + // the leading separator if there is only one separator. If there is a drive + // letter, start will be set appropriately to prevent stripping the first + // separator following the drive letter, if a separator immediately follows + // the drive letter. + StringType::size_type start = FindDriveLetter(path_) + 2; + + StringType::size_type last_stripped = StringType::npos; + for (StringType::size_type pos = path_.length(); + pos > start && IsSeparator(path_[pos - 1]); + --pos) { + // If the string only has two separators and they're at the beginning, + // don't strip them, unless the string began with more than two separators. + if (pos != start + 1 || last_stripped == start + 2 || + !IsSeparator(path_[start - 1])) { + path_.resize(pos - 1); + last_stripped = pos; + } + } +} +} // namespace base diff --git a/base/files/file_path.h b/base/files/file_path.h index 1849ae73..e689da38 100644 --- a/base/files/file_path.h +++ b/base/files/file_path.h @@ -1,131 +1,131 @@ -#ifndef QUICHE_BASE_FILES_FILE_PATH_H_ -#define QUICHE_BASE_FILES_FILE_PATH_H_ - -#include -#include -#include "gquiche/common/platform/api/quiche_logging.h" -#include "googleurl/base/compiler_specific.h" - -// Macros for string literal initialization of FilePath::CharType[]. -#if defined(OS_WIN) -#define FILE_PATH_LITERAL(x) L##x -#elif defined(OS_POSIX) || defined(OS_FUCHSIA) -#define FILE_PATH_LITERAL(x) x -#endif // OS_WIN - -namespace base { - -// An abstraction to isolate users from the differences between native -// pathnames on different platforms. -class FilePath { - public: -#if defined(OS_WIN) - // On Windows, for Unicode-aware applications, native pathnames are wchar_t - // arrays encoded in UTF-16. - typedef std::wstring StringType; -#elif defined(OS_POSIX) || defined(OS_FUCHSIA) - // On most platforms, native pathnames are char arrays, and the encoding - // may or may not be specified. On Mac OS X, native pathnames are encoded - // in UTF-8. - typedef std::string StringType; -#endif // OS_WIN - - typedef std::string StringPieceType; - typedef StringType::value_type CharType; - - // Null-terminated array of separators used to separate components in - // hierarchical paths. Each character in this array is a valid separator, - // but kSeparators[0] is treated as the canonical separator and will be used - // when composing pathnames. - static const CharType kSeparators[]; - - // base::size(kSeparators). - static const size_t kSeparatorsLength; - - // A special path component meaning "this directory." - static const CharType kCurrentDirectory[]; - - // A special path component meaning "the parent directory." - static const CharType kParentDirectory[]; - - // The character used to identify a file extension. - static const CharType kExtensionSeparator; - - FilePath(); - FilePath(const FilePath& that); - explicit FilePath(StringPieceType path); - ~FilePath(); - FilePath& operator=(const FilePath& that); - - // Constructs FilePath with the contents of |that|, which is left in valid but - // unspecified state. - FilePath(FilePath&& that) noexcept; - // Replaces the contents with those of |that|, which is left in valid but - // unspecified state. - FilePath& operator=(FilePath&& that) noexcept; - - bool operator==(const FilePath& that) const; - - bool operator!=(const FilePath& that) const; - - // Required for some STL containers and operations - bool operator<(const FilePath& that) const { - return path_ < that.path_; - } - - const StringType& value() const { return path_; } - - bool empty() const { return path_.empty(); } - - void clear() { path_.clear(); } - - // Returns true if |character| is in kSeparators. - static bool IsSeparator(CharType character); - - // Returns a vector of all of the components of the provided path. It is - // equivalent to calling DirName().value() on the path's root component, - // and BaseName().value() on each child component. - // - // To make sure this is lossless so we can differentiate absolute and - // relative paths, the root slash will be included even though no other - // slashes will be. The precise behavior is: - // - // Posix: "/foo/bar" -> [ "/", "foo", "bar" ] - // Windows: "C:\foo\bar" -> [ "C:", "\\", "foo", "bar" ] - void GetComponents(std::vector* components) const; - - // Returns a FilePath corresponding to the directory containing the path - // named by this object, stripping away the file component. If this object - // only contains one component, returns a FilePath identifying - // kCurrentDirectory. If this object already refers to the root directory, - // returns a FilePath identifying the root directory. Please note that this - // doesn't resolve directory navigation, e.g. the result for "../a" is "..". - FilePath DirName() const WARN_UNUSED_RESULT; - - // Returns a FilePath corresponding to the last path component of this - // object, either a file or a directory. If this object already refers to - // the root directory, returns a FilePath identifying the root directory; - // this is the only situation in which BaseName will return an absolute path. - FilePath BaseName() const WARN_UNUSED_RESULT; - - // Returns true if this FilePath contains an attempt to reference a parent - // directory (e.g. has a path component that is ".."). - bool ReferencesParent() const; - - private: - // Remove trailing separators from this object. If the path is absolute, it - // will never be stripped any more than to refer to the absolute root - // directory, so "////" will become "/", not "". A leading pair of - // separators is never stripped, to support alternate roots. This is used to - // support UNC paths on Windows. - void StripTrailingSeparatorsInternal(); - - StringType path_; -}; - -std::ostream& operator<<(std::ostream& out, - const FilePath& file_path); - -} // namespace base - -#endif // QUICHE_BASE_FILES_FILE_PATH_H_ +#ifndef QUICHE_BASE_FILES_FILE_PATH_H_ +#define QUICHE_BASE_FILES_FILE_PATH_H_ + +#include +#include +#include "gquiche/common/platform/api/quiche_logging.h" +#include "googleurl/base/compiler_specific.h" + +// Macros for string literal initialization of FilePath::CharType[]. +#if defined(OS_WIN) +#define FILE_PATH_LITERAL(x) L##x +#elif defined(OS_POSIX) || defined(OS_FUCHSIA) +#define FILE_PATH_LITERAL(x) x +#endif // OS_WIN + +namespace base { + +// An abstraction to isolate users from the differences between native +// pathnames on different platforms. +class FilePath { + public: +#if defined(OS_WIN) + // On Windows, for Unicode-aware applications, native pathnames are wchar_t + // arrays encoded in UTF-16. + typedef std::wstring StringType; +#elif defined(OS_POSIX) || defined(OS_FUCHSIA) + // On most platforms, native pathnames are char arrays, and the encoding + // may or may not be specified. On Mac OS X, native pathnames are encoded + // in UTF-8. + typedef std::string StringType; +#endif // OS_WIN + + typedef std::string StringPieceType; + typedef StringType::value_type CharType; + + // Null-terminated array of separators used to separate components in + // hierarchical paths. Each character in this array is a valid separator, + // but kSeparators[0] is treated as the canonical separator and will be used + // when composing pathnames. + static const CharType kSeparators[]; + + // base::size(kSeparators). + static const size_t kSeparatorsLength; + + // A special path component meaning "this directory." + static const CharType kCurrentDirectory[]; + + // A special path component meaning "the parent directory." + static const CharType kParentDirectory[]; + + // The character used to identify a file extension. + static const CharType kExtensionSeparator; + + FilePath(); + FilePath(const FilePath& that); + explicit FilePath(StringPieceType path); + ~FilePath(); + FilePath& operator=(const FilePath& that); + + // Constructs FilePath with the contents of |that|, which is left in valid but + // unspecified state. + FilePath(FilePath&& that) noexcept; + // Replaces the contents with those of |that|, which is left in valid but + // unspecified state. + FilePath& operator=(FilePath&& that) noexcept; + + bool operator==(const FilePath& that) const; + + bool operator!=(const FilePath& that) const; + + // Required for some STL containers and operations + bool operator<(const FilePath& that) const { + return path_ < that.path_; + } + + const StringType& value() const { return path_; } + + bool empty() const { return path_.empty(); } + + void clear() { path_.clear(); } + + // Returns true if |character| is in kSeparators. + static bool IsSeparator(CharType character); + + // Returns a vector of all of the components of the provided path. It is + // equivalent to calling DirName().value() on the path's root component, + // and BaseName().value() on each child component. + // + // To make sure this is lossless so we can differentiate absolute and + // relative paths, the root slash will be included even though no other + // slashes will be. The precise behavior is: + // + // Posix: "/foo/bar" -> [ "/", "foo", "bar" ] + // Windows: "C:\foo\bar" -> [ "C:", "\\", "foo", "bar" ] + void GetComponents(std::vector* components) const; + + // Returns a FilePath corresponding to the directory containing the path + // named by this object, stripping away the file component. If this object + // only contains one component, returns a FilePath identifying + // kCurrentDirectory. If this object already refers to the root directory, + // returns a FilePath identifying the root directory. Please note that this + // doesn't resolve directory navigation, e.g. the result for "../a" is "..". + FilePath DirName() const WARN_UNUSED_RESULT; + + // Returns a FilePath corresponding to the last path component of this + // object, either a file or a directory. If this object already refers to + // the root directory, returns a FilePath identifying the root directory; + // this is the only situation in which BaseName will return an absolute path. + FilePath BaseName() const WARN_UNUSED_RESULT; + + // Returns true if this FilePath contains an attempt to reference a parent + // directory (e.g. has a path component that is ".."). + bool ReferencesParent() const; + + private: + // Remove trailing separators from this object. If the path is absolute, it + // will never be stripped any more than to refer to the absolute root + // directory, so "////" will become "/", not "". A leading pair of + // separators is never stripped, to support alternate roots. This is used to + // support UNC paths on Windows. + void StripTrailingSeparatorsInternal(); + + StringType path_; +}; + +std::ostream& operator<<(std::ostream& out, + const FilePath& file_path); + +} // namespace base + +#endif // QUICHE_BASE_FILES_FILE_PATH_H_ diff --git a/base/files/file_path_constants.cc b/base/files/file_path_constants.cc index 879a007d..95709133 100644 --- a/base/files/file_path_constants.cc +++ b/base/files/file_path_constants.cc @@ -1,24 +1,24 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include - -#include "base/files/file_path.h" - -namespace base { - -#if defined(FILE_PATH_USES_WIN_SEPARATORS) -const FilePath::CharType FilePath::kSeparators[] = FILE_PATH_LITERAL("\\/"); -#else // FILE_PATH_USES_WIN_SEPARATORS -const FilePath::CharType FilePath::kSeparators[] = FILE_PATH_LITERAL("/"); -#endif // FILE_PATH_USES_WIN_SEPARATORS - -const size_t FilePath::kSeparatorsLength = std::size(kSeparators); - -const FilePath::CharType FilePath::kCurrentDirectory[] = FILE_PATH_LITERAL("."); -const FilePath::CharType FilePath::kParentDirectory[] = FILE_PATH_LITERAL(".."); - -const FilePath::CharType FilePath::kExtensionSeparator = FILE_PATH_LITERAL('.'); - +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include "base/files/file_path.h" + +namespace base { + +#if defined(FILE_PATH_USES_WIN_SEPARATORS) +const FilePath::CharType FilePath::kSeparators[] = FILE_PATH_LITERAL("\\/"); +#else // FILE_PATH_USES_WIN_SEPARATORS +const FilePath::CharType FilePath::kSeparators[] = FILE_PATH_LITERAL("/"); +#endif // FILE_PATH_USES_WIN_SEPARATORS + +const size_t FilePath::kSeparatorsLength = std::size(kSeparators); + +const FilePath::CharType FilePath::kCurrentDirectory[] = FILE_PATH_LITERAL("."); +const FilePath::CharType FilePath::kParentDirectory[] = FILE_PATH_LITERAL(".."); + +const FilePath::CharType FilePath::kExtensionSeparator = FILE_PATH_LITERAL('.'); + } // namespace base \ No newline at end of file diff --git a/base/files/file_util.cc b/base/files/file_util.cc index e04c2f51..f1f39ca3 100644 --- a/base/files/file_util.cc +++ b/base/files/file_util.cc @@ -1,86 +1,86 @@ -#include -#include -#include "base/files/file_util.h" -#include "base/files/file_path.h" -#include "base/files/scoped_file.h" -#include "base/posix/eintr_wrapper.h" - -namespace base { -bool ReadFileToString(const FilePath& path, std::string* contents) { - return ReadFileToStringWithMaxSize(path, contents, - std::numeric_limits::max()); -} - -bool ReadFileToStringWithMaxSize(const FilePath& path, - std::string* contents, - size_t max_size) { - if (contents) - contents->clear(); - if (path.ReferencesParent()) - return false; - ScopedFILE file_stream(OpenFile(path, "rb")); - if (!file_stream) - return false; - return ReadStreamToStringWithMaxSize(file_stream.get(), max_size, contents); -} - -bool ReadStreamToStringWithMaxSize(FILE* stream, - size_t max_size, - std::string* contents) { - if (contents) - contents->clear(); - - // Seeking to the beginning is best-effort -- it is expected to fail for - // certain non-file stream (e.g., pipes). - HANDLE_EINTR(fseek(stream, 0, SEEK_SET)); - - // Many files have incorrect size (proc files etc). Hence, the file is read - // sequentially as opposed to a one-shot read, using file size as a hint for - // chunk size if available. - constexpr int64_t kDefaultChunkSize = 1 << 16; - int64_t chunk_size = kDefaultChunkSize - 1; - - stat_wrapper_t file_info = {}; - if (!fstat64(fileno(stream), &file_info) && file_info.st_size > 0) - chunk_size = file_info.st_size; - - // We need to attempt to read at EOF for feof flag to be set so here we - // use |chunk_size| + 1. - chunk_size = std::min(chunk_size, max_size) + 1; - - size_t bytes_read_this_pass; - size_t bytes_read_so_far = 0; - bool read_status = true; - std::string local_contents; - local_contents.resize(chunk_size); - - while ((bytes_read_this_pass = fread(&local_contents[bytes_read_so_far], 1, - chunk_size, stream)) > 0) { - if ((max_size - bytes_read_so_far) < bytes_read_this_pass) { - // Read more than max_size bytes, bail out. - bytes_read_so_far = max_size; - read_status = false; - break; - } - // In case EOF was not reached, iterate again but revert to the default - // chunk size. - if (bytes_read_so_far == 0) - chunk_size = kDefaultChunkSize; - - bytes_read_so_far += bytes_read_this_pass; - // Last fread syscall (after EOF) can be avoided via feof, which is just a - // flag check. - if (feof(stream)) - break; - local_contents.resize(bytes_read_so_far + chunk_size); - } - read_status = read_status && !ferror(stream); - if (contents) { - contents->swap(local_contents); - contents->resize(bytes_read_so_far); - } - - return read_status; -} - -} //namespace base +#include +#include +#include "base/files/file_util.h" +#include "base/files/file_path.h" +#include "base/files/scoped_file.h" +#include "base/posix/eintr_wrapper.h" + +namespace base { +bool ReadFileToString(const FilePath& path, std::string* contents) { + return ReadFileToStringWithMaxSize(path, contents, + std::numeric_limits::max()); +} + +bool ReadFileToStringWithMaxSize(const FilePath& path, + std::string* contents, + size_t max_size) { + if (contents) + contents->clear(); + if (path.ReferencesParent()) + return false; + ScopedFILE file_stream(OpenFile(path, "rb")); + if (!file_stream) + return false; + return ReadStreamToStringWithMaxSize(file_stream.get(), max_size, contents); +} + +bool ReadStreamToStringWithMaxSize(FILE* stream, + size_t max_size, + std::string* contents) { + if (contents) + contents->clear(); + + // Seeking to the beginning is best-effort -- it is expected to fail for + // certain non-file stream (e.g., pipes). + HANDLE_EINTR(fseek(stream, 0, SEEK_SET)); + + // Many files have incorrect size (proc files etc). Hence, the file is read + // sequentially as opposed to a one-shot read, using file size as a hint for + // chunk size if available. + constexpr int64_t kDefaultChunkSize = 1 << 16; + int64_t chunk_size = kDefaultChunkSize - 1; + + stat_wrapper_t file_info = {}; + if (!fstat64(fileno(stream), &file_info) && file_info.st_size > 0) + chunk_size = file_info.st_size; + + // We need to attempt to read at EOF for feof flag to be set so here we + // use |chunk_size| + 1. + chunk_size = std::min(chunk_size, max_size) + 1; + + size_t bytes_read_this_pass; + size_t bytes_read_so_far = 0; + bool read_status = true; + std::string local_contents; + local_contents.resize(chunk_size); + + while ((bytes_read_this_pass = fread(&local_contents[bytes_read_so_far], 1, + chunk_size, stream)) > 0) { + if ((max_size - bytes_read_so_far) < bytes_read_this_pass) { + // Read more than max_size bytes, bail out. + bytes_read_so_far = max_size; + read_status = false; + break; + } + // In case EOF was not reached, iterate again but revert to the default + // chunk size. + if (bytes_read_so_far == 0) + chunk_size = kDefaultChunkSize; + + bytes_read_so_far += bytes_read_this_pass; + // Last fread syscall (after EOF) can be avoided via feof, which is just a + // flag check. + if (feof(stream)) + break; + local_contents.resize(bytes_read_so_far + chunk_size); + } + read_status = read_status && !ferror(stream); + if (contents) { + contents->swap(local_contents); + contents->resize(bytes_read_so_far); + } + + return read_status; +} + +} //namespace base diff --git a/base/files/file_util.h b/base/files/file_util.h index 9eaca3be..d5309a57 100644 --- a/base/files/file_util.h +++ b/base/files/file_util.h @@ -1,52 +1,52 @@ -#ifndef QUICHE_BASE_FILES_FILE_UTIL_H_ -#define QUICHE_BASE_FILES_FILE_UTIL_H_ - -#if defined(OS_POSIX) || defined(OS_FUCHSIA) -#include -#include -#endif - -#include - -#include "base/files/file_path.h" - -namespace base { -class FilePath; - -typedef struct stat64 stat_wrapper_t; - -// Reads the file at |path| into |contents| and returns true on success and -// false on error. For security reasons, a |path| containing path traversal -// components ('..') is treated as a read error and |contents| is set to empty. -// In case of I/O error, |contents| holds the data that could be read from the -// file before the error occurred. -// |contents| may be NULL, in which case this function is useful for its side -// effect of priming the disk cache (could be used for unit tests). -bool ReadFileToString(const FilePath& path, std::string* contents); - -// Reads the file at |path| into |contents| and returns true on success and -// false on error. For security reasons, a |path| containing path traversal -// components ('..') is treated as a read error and |contents| is set to empty. -// In case of I/O error, |contents| holds the data that could be read from the -// file before the error occurred. When the file size exceeds |max_size|, the -// function returns false with |contents| holding the file truncated to -// |max_size|. -// |contents| may be NULL, in which case this function is useful for its side -// effect of priming the disk cache (could be used for unit tests). -bool ReadFileToStringWithMaxSize(const FilePath& path, - std::string* contents, - size_t max_size); - -// As ReadFileToStringWithMaxSize, but reading from an open stream after seeking -// to its start (if supported by the stream). -bool ReadStreamToStringWithMaxSize(FILE* stream, - size_t max_size, - std::string* contents); -// Wrapper for fopen-like calls. Returns non-NULL FILE* on success. The -// underlying file descriptor (POSIX) or handle (Windows) is unconditionally -// configured to not be propagated to child processes. -FILE* OpenFile(const FilePath& filename, const char* mode); - -} // namespace base - -#endif //QUICHE_BASE_FILES_FILE_UTIL_H_ +#ifndef QUICHE_BASE_FILES_FILE_UTIL_H_ +#define QUICHE_BASE_FILES_FILE_UTIL_H_ + +#if defined(OS_POSIX) || defined(OS_FUCHSIA) +#include +#include +#endif + +#include + +#include "base/files/file_path.h" + +namespace base { +class FilePath; + +typedef struct stat64 stat_wrapper_t; + +// Reads the file at |path| into |contents| and returns true on success and +// false on error. For security reasons, a |path| containing path traversal +// components ('..') is treated as a read error and |contents| is set to empty. +// In case of I/O error, |contents| holds the data that could be read from the +// file before the error occurred. +// |contents| may be NULL, in which case this function is useful for its side +// effect of priming the disk cache (could be used for unit tests). +bool ReadFileToString(const FilePath& path, std::string* contents); + +// Reads the file at |path| into |contents| and returns true on success and +// false on error. For security reasons, a |path| containing path traversal +// components ('..') is treated as a read error and |contents| is set to empty. +// In case of I/O error, |contents| holds the data that could be read from the +// file before the error occurred. When the file size exceeds |max_size|, the +// function returns false with |contents| holding the file truncated to +// |max_size|. +// |contents| may be NULL, in which case this function is useful for its side +// effect of priming the disk cache (could be used for unit tests). +bool ReadFileToStringWithMaxSize(const FilePath& path, + std::string* contents, + size_t max_size); + +// As ReadFileToStringWithMaxSize, but reading from an open stream after seeking +// to its start (if supported by the stream). +bool ReadStreamToStringWithMaxSize(FILE* stream, + size_t max_size, + std::string* contents); +// Wrapper for fopen-like calls. Returns non-NULL FILE* on success. The +// underlying file descriptor (POSIX) or handle (Windows) is unconditionally +// configured to not be propagated to child processes. +FILE* OpenFile(const FilePath& filename, const char* mode); + +} // namespace base + +#endif //QUICHE_BASE_FILES_FILE_UTIL_H_ diff --git a/base/files/file_util_posix.cc b/base/files/file_util_posix.cc index 5b9e39a7..92f3e42d 100644 --- a/base/files/file_util_posix.cc +++ b/base/files/file_util_posix.cc @@ -1,36 +1,36 @@ -#include -#include -#include -#include -#include - -#include "base/files/file_util.h" -#include "base/files/file_path.h" - -#include "gquiche/quic/platform/api/quic_logging.h" - -namespace base { - -std::string AppendModeCharacter(const char* mode, char mode_char) { - std::string result(mode); - size_t comma_pos = result.find(','); - result.insert(comma_pos == std::string::npos ? result.length() : comma_pos, 1, - mode_char); - return result; -} - -FILE* OpenFile(const FilePath& filename, const char* mode) { - // 'e' is unconditionally added below, so be sure there is not one already - // present before a comma in |mode|. - QUICHE_DCHECK( - strchr(mode, 'e') == nullptr || - (strchr(mode, ',') != nullptr && strchr(mode, 'e') > strchr(mode, ','))); - FILE* result = nullptr; - std::string mode_with_e(AppendModeCharacter(mode, 'e')); - const char* the_mode = mode_with_e.c_str(); - do { - result = fopen(filename.value().c_str(), the_mode); - } while (!result && errno == EINTR); - return result; -} -} // namespace base +#include +#include +#include +#include +#include + +#include "base/files/file_util.h" +#include "base/files/file_path.h" + +#include "gquiche/quic/platform/api/quic_logging.h" + +namespace base { + +std::string AppendModeCharacter(const char* mode, char mode_char) { + std::string result(mode); + size_t comma_pos = result.find(','); + result.insert(comma_pos == std::string::npos ? result.length() : comma_pos, 1, + mode_char); + return result; +} + +FILE* OpenFile(const FilePath& filename, const char* mode) { + // 'e' is unconditionally added below, so be sure there is not one already + // present before a comma in |mode|. + QUICHE_DCHECK( + strchr(mode, 'e') == nullptr || + (strchr(mode, ',') != nullptr && strchr(mode, 'e') > strchr(mode, ','))); + FILE* result = nullptr; + std::string mode_with_e(AppendModeCharacter(mode, 'e')); + const char* the_mode = mode_with_e.c_str(); + do { + result = fopen(filename.value().c_str(), the_mode); + } while (!result && errno == EINTR); + return result; +} +} // namespace base diff --git a/base/files/scoped_file.h b/base/files/scoped_file.h index 885e0de9..4050c8f3 100644 --- a/base/files/scoped_file.h +++ b/base/files/scoped_file.h @@ -1,26 +1,26 @@ -#ifndef QUICHE_BASE_FILES_SCOPED_FILE_H_ -#define QUICHE_BASE_FILES_SCOPED_FILE_H_ - -#include -#include - -namespace base { - -namespace internal { - -// Functor for |ScopedFILE| (below). -struct ScopedFILECloser { - inline void operator()(FILE* x) const { - if (x) - fclose(x); - } -}; - -} // namespace internal - -// Automatically closes |FILE*|s. -typedef std::unique_ptr ScopedFILE; - -} // namespace base - -#endif // QUICHE_BASE_FILES_SCOPED_FILE_H_ +#ifndef QUICHE_BASE_FILES_SCOPED_FILE_H_ +#define QUICHE_BASE_FILES_SCOPED_FILE_H_ + +#include +#include + +namespace base { + +namespace internal { + +// Functor for |ScopedFILE| (below). +struct ScopedFILECloser { + inline void operator()(FILE* x) const { + if (x) + fclose(x); + } +}; + +} // namespace internal + +// Automatically closes |FILE*|s. +typedef std::unique_ptr ScopedFILE; + +} // namespace base + +#endif // QUICHE_BASE_FILES_SCOPED_FILE_H_ diff --git a/base/posix/eintr_wrapper.h b/base/posix/eintr_wrapper.h index 4d56747a..c5101e96 100644 --- a/base/posix/eintr_wrapper.h +++ b/base/posix/eintr_wrapper.h @@ -1,68 +1,68 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -// This provides a wrapper around system calls which may be interrupted by a -// signal and return EINTR. See man 7 signal. -// To prevent long-lasting loops (which would likely be a bug, such as a signal -// that should be masked) to go unnoticed, there is a limit after which the -// caller will nonetheless see an EINTR in Debug builds. -// -// On Windows and Fuchsia, this wrapper macro does nothing because there are no -// signals. -// -// Don't wrap close calls in HANDLE_EINTR. Use IGNORE_EINTR if the return -// value of close is significant. See http://crbug.com/269623. - -#ifndef QUICHE_POSIX_EINTR_WRAPPER_H_ -#define QUICHE_POSIX_EINTR_WRAPPER_H_ - -#include "googleurl/build/build_config.h" - -#if defined(OS_POSIX) - -#include - -#if defined(NDEBUG) - -#define HANDLE_EINTR(x) ({ \ - decltype(x) eintr_wrapper_result; \ - do { \ - eintr_wrapper_result = (x); \ - } while (eintr_wrapper_result == -1 && errno == EINTR); \ - eintr_wrapper_result; \ -}) - -#else - -#define HANDLE_EINTR(x) ({ \ - int eintr_wrapper_counter = 0; \ - decltype(x) eintr_wrapper_result; \ - do { \ - eintr_wrapper_result = (x); \ - } while (eintr_wrapper_result == -1 && errno == EINTR && \ - eintr_wrapper_counter++ < 100); \ - eintr_wrapper_result; \ -}) - -#endif // NDEBUG - -#define IGNORE_EINTR(x) ({ \ - decltype(x) eintr_wrapper_result; \ - do { \ - eintr_wrapper_result = (x); \ - if (eintr_wrapper_result == -1 && errno == EINTR) { \ - eintr_wrapper_result = 0; \ - } \ - } while (0); \ - eintr_wrapper_result; \ -}) - -#else // !OS_POSIX - -#define HANDLE_EINTR(x) (x) -#define IGNORE_EINTR(x) (x) - -#endif // !OS_POSIX - +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This provides a wrapper around system calls which may be interrupted by a +// signal and return EINTR. See man 7 signal. +// To prevent long-lasting loops (which would likely be a bug, such as a signal +// that should be masked) to go unnoticed, there is a limit after which the +// caller will nonetheless see an EINTR in Debug builds. +// +// On Windows and Fuchsia, this wrapper macro does nothing because there are no +// signals. +// +// Don't wrap close calls in HANDLE_EINTR. Use IGNORE_EINTR if the return +// value of close is significant. See http://crbug.com/269623. + +#ifndef QUICHE_POSIX_EINTR_WRAPPER_H_ +#define QUICHE_POSIX_EINTR_WRAPPER_H_ + +#include "googleurl/build/build_config.h" + +#if defined(OS_POSIX) + +#include + +#if defined(NDEBUG) + +#define HANDLE_EINTR(x) ({ \ + decltype(x) eintr_wrapper_result; \ + do { \ + eintr_wrapper_result = (x); \ + } while (eintr_wrapper_result == -1 && errno == EINTR); \ + eintr_wrapper_result; \ +}) + +#else + +#define HANDLE_EINTR(x) ({ \ + int eintr_wrapper_counter = 0; \ + decltype(x) eintr_wrapper_result; \ + do { \ + eintr_wrapper_result = (x); \ + } while (eintr_wrapper_result == -1 && errno == EINTR && \ + eintr_wrapper_counter++ < 100); \ + eintr_wrapper_result; \ +}) + +#endif // NDEBUG + +#define IGNORE_EINTR(x) ({ \ + decltype(x) eintr_wrapper_result; \ + do { \ + eintr_wrapper_result = (x); \ + if (eintr_wrapper_result == -1 && errno == EINTR) { \ + eintr_wrapper_result = 0; \ + } \ + } while (0); \ + eintr_wrapper_result; \ +}) + +#else // !OS_POSIX + +#define HANDLE_EINTR(x) (x) +#define IGNORE_EINTR(x) (x) + +#endif // !OS_POSIX + #endif // QUICHE_POSIX_EINTR_WRAPPER_H_ \ No newline at end of file diff --git a/base/sinks/customize_file_helper-inl.h b/base/sinks/customize_file_helper-inl.h new file mode 100644 index 00000000..909f3189 --- /dev/null +++ b/base/sinks/customize_file_helper-inl.h @@ -0,0 +1,89 @@ +#pragma once + +#ifndef SPDLOG_HEADER_ONLY +#include "spdlog/details/customize_file_helper.h" +#endif + +#include "spdlog/details/os.h" +#include "spdlog/common.h" + +#include +#include +#include +#include +#include +#include + +namespace spdlog { +namespace details { + +SPDLOG_INLINE customize_file_helper::~customize_file_helper() { + close(); +} + +SPDLOG_INLINE void customize_file_helper::open(const filename_t &fname, bool truncate) { + close(); + filename_ = fname; + auto *mode = truncate ? SPDLOG_FILENAME_T("wb") : SPDLOG_FILENAME_T("ab"); + + for (int tries = 0; tries < open_tries_; ++tries) { + // create containing folder if not exists already. + os::create_dir(os::dir_name(fname)); + if (!os::fopen_s(&fd_, fname, mode)) { + return; + } + details::os::sleep_for_millis(open_interval_); + } + + throw_spdlog_ex("Failed opening file " + os::filename_to_str(filename_) + " for writing", errno); +} + +SPDLOG_INLINE void customize_file_helper::reopen(bool truncate) { + if (filename_.empty()) { + throw_spdlog_ex("Failed re opening file - was not opened before"); + } + this->open(filename_, truncate); +} + +SPDLOG_INLINE void customize_file_helper::flush() { + std::fflush(fd_); +} + +SPDLOG_INLINE void customize_file_helper::close() { + if (fd_ != nullptr) { + std::fclose(fd_); + fd_ = nullptr; + } +} + +SPDLOG_INLINE void customize_file_helper::write(const memory_buf_t &buf, size_t pos) { + size_t msg_size = buf.size(); + if(msg_size - 1 < pos) { + throw_spdlog_ex("buffer pos error"); + return; + } + + msg_size = msg_size - pos; + auto data = buf.data() + pos; + if (std::fwrite(data, 1, msg_size, fd_) != msg_size) { + throw_spdlog_ex("Failed writing to file " + details::os::filename_to_str(filename_), errno); + } +} + +SPDLOG_INLINE void customize_file_helper::write(const std::string &buf) { + size_t msg_size = buf.size(); + auto data = buf.data(); + if (std::fwrite(data, 1, msg_size, fd_) != msg_size) { + throw_spdlog_ex("Failed writing to file " + os::filename_to_str(filename_), errno); + } +} + +SPDLOG_INLINE size_t customize_file_helper::size() const { + if (fd_ == nullptr) { + throw_spdlog_ex("Cannot use size() on closed file " + os::filename_to_str(filename_)); + } + return os::filesize(fd_); +} + +} // namespace details +} // namespace spdlog diff --git a/base/sinks/customize_file_helper.h b/base/sinks/customize_file_helper.h new file mode 100644 index 00000000..f6a5c45c --- /dev/null +++ b/base/sinks/customize_file_helper.h @@ -0,0 +1,37 @@ +#pragma once + +#include "spdlog/common.h" +#include +#include + +namespace spdlog { +namespace details { + +class SPDLOG_API customize_file_helper { +public: + explicit customize_file_helper() = default; + + customize_file_helper(const customize_file_helper &) = delete; + customize_file_helper &operator=(const customize_file_helper &) = delete; + ~customize_file_helper(); + + void open(const filename_t &fname, bool truncate = false); + void reopen(bool truncate); + void flush(); + void close(); + void write(const memory_buf_t &buf, size_t pos); + void write(const std::string &buf); + size_t size() const; + +private: + const int open_tries_ = 5; + const int open_interval_ = 10; + std::FILE *fd_{nullptr}; + filename_t filename_; +}; +} // namespace details +} // namespace spdlog + +#ifdef SPDLOG_HEADER_ONLY +#include "customize_file_helper-inl.h" +#endif \ No newline at end of file diff --git a/base/sinks/sequence_file_sink-inl.h b/base/sinks/sequence_file_sink-inl.h new file mode 100644 index 00000000..1ec88b95 --- /dev/null +++ b/base/sinks/sequence_file_sink-inl.h @@ -0,0 +1,278 @@ +#pragma once + +#ifndef SPDLOG_HEADER_ONLY +#include "base/sinks/sequence_file_sink.h" +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "spdlog/common.h" +#include "spdlog/details/file_helper.h" +#include "spdlog/details/null_mutex.h" +#include "spdlog/fmt/fmt.h" +namespace spdlog { +namespace sinks { + +template + +SPDLOG_INLINE sequence_file_sink::sequence_file_sink( + filename_t base_filename, filename_t final_filename, std::size_t max_size, std::size_t max_files, std::string metadata_head, + std::string metadata_tail, std::shared_ptr last_summary, std::shared_ptr pre_summary, std::shared_ptr lock, std::size_t free_file_number) + : base_filename_(std::move(base_filename)), + final_filename_(std::move(final_filename)), + max_size_(max_size), + max_files_(max_files), + free_file_number_(free_file_number), + metadata_head_(metadata_head), + metadata_tail_(metadata_tail), + last_summary_(last_summary), + pre_summary_(pre_summary), + lock_(lock){ + //file open + file_helper_.open(calc_filename(base_filename_, 0),true); + current_size_ = file_helper_.size(); // expensive. called only once + + is_first_event_ = false; + metadata_head_size_ = metadata_head_.size(); + metadata_tail_size_ = metadata_tail_.size(); +} + +// calc filename according to index and file extension if exists. +// e.g. calc_filename("logs/mylog.txt, 3) => "logs/mylog_3.txt". +// e.g. calc_filename("logs/mylog.txt, 0) => "logs/mylog.txt". +template +SPDLOG_INLINE filename_t sequence_file_sink::calc_filename(const filename_t &filename, std::size_t index) { + if (index == 0u) { + return filename; + } + + filename_t basename, ext; + std::tie(basename, ext) = details::file_helper::split_by_extension(filename); + return fmt::format(SPDLOG_FILENAME_T("{}_{}{}"), basename, index, ext); +} + +template +SPDLOG_INLINE void sequence_file_sink::sink_it_(const details::log_msg &msg) { + memory_buf_t formatted; + base_sink::formatter_->format(msg, formatted); + + size_t formatted_size = formatted.size(); + char begin_of_msg = formatted[0]; + + if (begin_of_msg == ']') { + file_helper_.write(formatted, 0); + current_size_ += formatted_size; + file_helper_.close(); + return; + } + + if (begin_of_msg == '{') { + file_helper_.write(formatted, 0); + current_size_ += formatted_size; + is_first_event_ = true; + } else if (begin_of_msg == ',') { + if(is_first_event_) { + file_helper_.write(formatted, 1); + current_size_ += formatted_size-1; + is_first_event_ = false; + } else { + file_helper_.write(formatted, 0); + current_size_ += formatted_size; + } + } else { // begin_of_msg == " " + file_helper_.write(formatted, 0); + current_size_ += formatted_size; + is_first_event_ = false; + } + + if (current_size_ > max_size_) { + std::string summary_data = currentSummary(); + file_helper_.write(metadata_tail_); + file_helper_.write(summary_data); + flush_(); + + //split file + sequence_(); + file_helper_.write(metadata_head_); + current_size_ = metadata_head_size_; + + free_file_number_ = (free_file_number_ + 1) % max_files_; + if (free_file_number_ == 0) + free_file_number_ = 1; + } +} + +template +SPDLOG_INLINE void sequence_file_sink::flush_() { + file_helper_.flush(); +} + +template +SPDLOG_INLINE std::string sequence_file_sink::currentSummary() { + std::string msg; + absl::StrAppend(&msg, "\"summary\":"); + + lock_->WriterLock(); + Document pre_summary; + pre_summary.CopyFrom(*pre_summary_, pre_summary.GetAllocator()); + + Document last_summary,tmp_summary; + last_summary.CopyFrom(*last_summary_, last_summary.GetAllocator()); + tmp_summary.CopyFrom(*last_summary_, tmp_summary.GetAllocator()); + *pre_summary_ = std::move(tmp_summary); + lock_->WriterUnlock(); + + Document summary; + summary.SetObject(); + Document::AllocatorType& summary_allocator = summary.GetAllocator(); + + Value key,value; + summary.AddMember("trace_count", last_summary["trace_count"].GetInt64(), summary_allocator); + summary.AddMember("max_duration", last_summary["max_duration"].GetInt64(), summary_allocator); + summary.AddMember("total_event_count", last_summary["total_event_count"].GetInt64() - pre_summary["total_event_count"].GetInt64(), summary_allocator); + + value.SetObject(); + value = std::move(last_summary["report_summary"]); + value["total_bytes_sent"] = value["total_bytes_sent"].GetInt64() - pre_summary["report_summary"]["total_bytes_sent"].GetInt64(); + value["total_bytes_recvd"] = value["total_bytes_recvd"].GetInt64() - pre_summary["report_summary"]["total_bytes_recvd"].GetInt64(); + value["total_packets_recvd"] = value["total_packets_recvd"].GetInt64() - pre_summary["report_summary"]["total_packets_recvd"].GetInt64(); + value["total_packets_sent"] = value["total_packets_sent"].GetInt64() - pre_summary["report_summary"]["total_packets_sent"].GetInt64(); + value["total_packets_lost"] = value["total_packets_lost"].GetInt64() - pre_summary["report_summary"]["total_packets_lost"].GetInt64(); + if(value["congestion_type"] != "Cubic") { + value.AddMember("pre_startup_duration", value["total_startup_duration"].GetDouble() - pre_summary["report_summary"]["total_startup_duration"].GetDouble(), summary_allocator); + value.AddMember("pre_drain_duration", value["total_drain_duration"].GetDouble() - pre_summary["report_summary"]["total_drain_duration"].GetDouble(), summary_allocator); + value.AddMember("pre_probebw_duration", value["total_probebw_duration"].GetDouble() - pre_summary["report_summary"]["total_probebw_duration"].GetDouble(), summary_allocator); + value.AddMember("pre_probertt_duration", value["total_probertt_duration"].GetDouble() - pre_summary["report_summary"]["total_probertt_duration"].GetDouble(), summary_allocator); + } + summary.AddMember("report_summary", value, summary_allocator); + + std::unordered_map last_stream; + std::unordered_map pre_stream; + auto lastStream_arr = last_summary["stream_map"].GetArray(); + auto preStream_arr = pre_summary["stream_map"].GetArray(); + for(auto it = lastStream_arr.Begin(); it != lastStream_arr.End(); it++) { + last_stream.insert({(*it)[0].GetInt64(), (*it)[1].GetInt64()}); + } + for(auto it = preStream_arr.Begin(); it != preStream_arr.End(); it++) { + pre_stream.insert({(*it)[0].GetInt64(), (*it)[1].GetInt64()}); + } + std::unordered_set res; + for(auto it = last_stream.begin(); it!= last_stream.end(); it++) { + res.insert(it->first); + } + for(auto it1 = pre_stream.begin(); it1 != pre_stream.end(); it1++) { + auto it2 = last_stream.find(it1->first); + if(it2 == last_stream.end()) { + continue; + } + size_t frame_data = it2->second - it1->second; + if(frame_data == 0) { + res.erase(it2->first); + } + } +#ifdef QLOG_FOR_QBONE + value.SetArray(); + for(auto it = res.begin(); it != res.end(); it++) { + value.PushBack(*it, summary_allocator); + } + summary.AddMember("stream_map", value, summary_allocator); +#else + std::unordered_map> last_map; + auto lastMap_arr = last_summary["uri_map"].GetArray(); + for(auto it = lastMap_arr.Begin(); it != lastMap_arr.End(); it++) { + std::vector temp_vector; + for(auto vector_it =(*it)[1].GetArray().Begin(); vector_it!= (*it)[1].GetArray().End(); vector_it++) { + if(res.find(vector_it->GetInt64()) == res.end()) { + continue; + } + temp_vector.push_back(vector_it->GetInt64()); + } + if(temp_vector.size() == 0 ){ + continue; + } + last_map.insert({(*it)[0].GetString(), temp_vector}); + } + + Value uriMap_value; + uriMap_value.SetArray(); + for (auto it = last_map.begin(); it != last_map.end(); it++) { + size_t current_size = (*it).second.size(); + Value temp_value; + temp_value.SetArray(); + for (size_t j = 0; j < current_size; j++) { + temp_value.PushBack(((*it).second)[j], summary_allocator); + } + + value.SetArray(); + value.PushBack(Value(((*it).first).c_str(), summary_allocator).Move(), summary_allocator); + value.PushBack(temp_value, summary_allocator); + uriMap_value.PushBack(value, summary_allocator); + } + summary.AddMember("uri_map", uriMap_value, summary_allocator); +#endif + StringBuffer buffer; + Writer writer(buffer); + summary.Accept(writer); + std::string str = buffer.GetString(); + absl::StrAppend(&msg, str, "}"); + return msg; +} + +// sequence move file : src -> target +// $Path/tmp/Cid/Cid.qlog -> $Path/Cid/Cid_1.qlog +// $Path/tmp/Cid/Cid.qlog -> $Path/Cid/Cid_2.qlog +// $Path/tmp/Cid/Cid.qlog -> $Path/Cid/Cid_3.qlog +// ... +// $Path/tmp/Cid/Cid.qlog -> $Path/Cid/Cid_(max_size_-1).qlog +// $Path/tmp/Cid/Cid.qlog -> $Path/Cid/Cid_1.qlog +template +SPDLOG_INLINE void sequence_file_sink::sequence_() { + using details::os::filename_to_str; + using details::os::path_exists; + + //file close; + file_helper_.close(); + + filename_t src = calc_filename(base_filename_, 0); + if (!path_exists(src)) { + throw_spdlog_ex("file path " + details::os::filename_to_str(base_filename_) + " not exist", errno); + details::os::create_dir(details::os::dir_name(src)); + file_helper_.reopen(true); + is_first_event_ = true; + return; + } + filename_t target = calc_filename(final_filename_, free_file_number_); + if (!path_exists(target)) { + details::os::create_dir(details::os::dir_name(target)); + } + if (!rename_file_(src, target)) { + details::os::sleep_for_millis(100); + if (!rename_file_(src, target)) { + // truncate the log file anyway to prevent it to grow beyond its limit! + file_helper_.reopen(true); + current_size_ = 0; + throw_spdlog_ex("sequence_file_sink: failed renaming " + filename_to_str(src) + " to " + filename_to_str(target), errno); + } + } + file_helper_.reopen(true); + is_first_event_ = true; +} + +// delete the target if exists, and rename the src file to target +// return true on success, false otherwise. +template +SPDLOG_INLINE bool sequence_file_sink::rename_file_(const filename_t &src_filename, const filename_t &target_filename) { + // try to delete the target file in case it already exists. + (void)details::os::remove(target_filename); + return details::os::rename(src_filename, target_filename) == 0; +} + +} // namespace sinks +} // namespace spdlog + diff --git a/base/sinks/sequence_file_sink.h b/base/sinks/sequence_file_sink.h new file mode 100644 index 00000000..74e07456 --- /dev/null +++ b/base/sinks/sequence_file_sink.h @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "customize_file_helper.h" +#include "base/bvc-qlog/src/qlogger_types.h" +#include "gquiche/quic/platform/api/quic_mutex.h" + +#include "spdlog/sinks/base_sink.h" +#include "spdlog/details/file_helper.h" +#include "spdlog/details/null_mutex.h" +#include "spdlog/details/synchronous_factory.h" +#include "absl/strings/str_cat.h" +#include "rapidjson/document.h" +#include "rapidjson/writer.h" +#include "rapidjson/stringbuffer.h" + +using quic::QuicMutex; +namespace spdlog { +namespace sinks { + +using namespace rapidjson; +// +// sequence file sink based on size +// +template +class sequence_file_sink final : public base_sink { +public: + sequence_file_sink(filename_t base_filename, + filename_t final_filename, + std::size_t max_size, + std::size_t max_files, + std::string metadata_head, + std::string metadata_tail, + std::shared_ptr last_summary, + std::shared_ptr pre_summary, + std::shared_ptr lock, + std::size_t free_file_number = 1 ); + static filename_t calc_filename(const filename_t &filename, std::size_t index); + ~sequence_file_sink() = default; + +protected: + void sink_it_(const details::log_msg &msg) override; + void flush_() override; + +private: + void sequence_(); + std::string currentSummary(); + + // delete the target if exists, and rename the src file to target + // return true on success, false otherwise. + bool rename_file_(const filename_t &src_filename, const filename_t &target_filename); + + filename_t base_filename_; + filename_t final_filename_; + std::size_t max_size_; + std::size_t max_files_; + std::size_t free_file_number_; + std::size_t current_size_; + + std::string metadata_head_; + std::string metadata_tail_; + size_t metadata_head_size_; + size_t metadata_tail_size_; + + std::shared_ptr last_summary_; + std::shared_ptr pre_summary_; + std::shared_ptr lock_; + + details::customize_file_helper file_helper_; + + //std::FILE *fd_{nullptr}; + + bool is_first_event_; +}; + +using sequence_file_sink_mt = sequence_file_sink; +using sequence_file_sink_st = sequence_file_sink; + +} // namespace sinks + +// +// factory functions +// +template +inline std::shared_ptr sequence_logger_mt( + const std::string &logger_name, const filename_t &filename, size_t max_file_size, size_t max_files, bool rotate_on_open = false) { + return Factory::template create(logger_name, filename, max_file_size, max_files, rotate_on_open); +} + +template +inline std::shared_ptr sequence_logger_st( + const std::string &logger_name, const filename_t &filename, size_t max_file_size, size_t max_files, bool rotate_on_open = false) { + return Factory::template create(logger_name, filename, max_file_size, max_files, rotate_on_open); +} + +} // namespace spdlog + +#ifdef SPDLOG_HEADER_ONLY +#include "base/sinks/sequence_file_sink-inl.h" +#endif + diff --git a/base/strings/stringprintf.cc b/base/strings/stringprintf.cc index 9bb08af7..b1d30499 100644 --- a/base/strings/stringprintf.cc +++ b/base/strings/stringprintf.cc @@ -1,135 +1,135 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "base/strings/stringprintf.h" - -#include -#include - -#include - -#include "polyfills/base/logging.h" -#include "base/stl_util.h" -#include "base/strings/string_util.h" -#include "base/strings/utf_string_conversions.h" -#include "build/build_config.h" - -namespace base { - -namespace { - -class ScopedClearLastError { - public: - ScopedClearLastError() : last_errno_(errno) { errno = 0; } - ~ScopedClearLastError() { errno = last_errno_; } - - private: - const int last_errno_; - - DISALLOW_COPY_AND_ASSIGN(ScopedClearLastError); -}; - -// Overloaded wrappers around vsnprintf and vswprintf. The buf_size parameter -// is the size of the buffer. These return the number of characters in the -// formatted string excluding the NUL terminator. If the buffer is not -// large enough to accommodate the formatted string without truncation, they -// return the number of characters that would be in the fully-formatted string -// (vsnprintf, and vswprintf on Windows), or -1 (vswprintf on POSIX platforms). -inline int vsnprintfT(char* buffer, - size_t buf_size, - const char* format, - va_list argptr) { - return gurl_base::vsnprintf(buffer, buf_size, format, argptr); -} - -// Templatized backend for StringPrintF/StringAppendF. This does not finalize -// the va_list, the caller is expected to do that. -template -static void StringAppendVT(std::basic_string* dst, - const CharT* format, - va_list ap) { - // First try with a small fixed size buffer. - // This buffer size should be kept in sync with StringUtilTest.GrowBoundary - // and StringUtilTest.StringPrintfBounds. - CharT stack_buf[1024]; - - va_list ap_copy; - va_copy(ap_copy, ap); - - ScopedClearLastError last_error; - int result = vsnprintfT(stack_buf, gurl_base::size(stack_buf), format, ap_copy); - va_end(ap_copy); - - if (result >= 0 && result < static_cast(gurl_base::size(stack_buf))) { - // It fit. - dst->append(stack_buf, result); - return; - } - - // Repeatedly increase buffer size until it fits. - int mem_length = gurl_base::size(stack_buf); - while (true) { - if (result < 0) { - if (errno != 0 && errno != EOVERFLOW) - return; - // Try doubling the buffer size. - mem_length *= 2; - } else { - // We need exactly "result + 1" characters. - mem_length = result + 1; - } - - if (mem_length > 32 * 1024 * 1024) { - // That should be plenty, don't try anything larger. This protects - // against huge allocations when using vsnprintfT implementations that - // return -1 for reasons other than overflow without setting errno. - GURL_DLOG(WARNING) << "Unable to printf the requested string due to size."; - return; - } - - std::vector mem_buf(mem_length); - - // NOTE: You can only use a va_list once. Since we're in a while loop, we - // need to make a new copy each time so we don't use up the original. - va_copy(ap_copy, ap); - result = vsnprintfT(&mem_buf[0], mem_length, format, ap_copy); - va_end(ap_copy); - - if ((result >= 0) && (result < mem_length)) { - // It fit. - dst->append(&mem_buf[0], result); - return; - } - } -} - -} // namespace - -std::string StringPrintV(const char* format, va_list ap) { - std::string result; - StringAppendV(&result, format, ap); - return result; -} - -const std::string& SStringPrintf(std::string* dst, const char* format, ...) { - va_list ap; - va_start(ap, format); - dst->clear(); - StringAppendV(dst, format, ap); - va_end(ap); - return *dst; -} - -void StringAppendF(std::string* dst, const char* format, ...) { - va_list ap; - va_start(ap, format); - StringAppendV(dst, format, ap); - va_end(ap); -} - -void StringAppendV(std::string* dst, const char* format, va_list ap) { - StringAppendVT(dst, format, ap); -} - -} // namespace base +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/strings/stringprintf.h" + +#include +#include + +#include + +#include "polyfills/base/logging.h" +#include "base/stl_util.h" +#include "base/strings/string_util.h" +#include "base/strings/utf_string_conversions.h" +#include "build/build_config.h" + +namespace base { + +namespace { + +class ScopedClearLastError { + public: + ScopedClearLastError() : last_errno_(errno) { errno = 0; } + ~ScopedClearLastError() { errno = last_errno_; } + + private: + const int last_errno_; + + DISALLOW_COPY_AND_ASSIGN(ScopedClearLastError); +}; + +// Overloaded wrappers around vsnprintf and vswprintf. The buf_size parameter +// is the size of the buffer. These return the number of characters in the +// formatted string excluding the NUL terminator. If the buffer is not +// large enough to accommodate the formatted string without truncation, they +// return the number of characters that would be in the fully-formatted string +// (vsnprintf, and vswprintf on Windows), or -1 (vswprintf on POSIX platforms). +inline int vsnprintfT(char* buffer, + size_t buf_size, + const char* format, + va_list argptr) { + return gurl_base::vsnprintf(buffer, buf_size, format, argptr); +} + +// Templatized backend for StringPrintF/StringAppendF. This does not finalize +// the va_list, the caller is expected to do that. +template +static void StringAppendVT(std::basic_string* dst, + const CharT* format, + va_list ap) { + // First try with a small fixed size buffer. + // This buffer size should be kept in sync with StringUtilTest.GrowBoundary + // and StringUtilTest.StringPrintfBounds. + CharT stack_buf[1024]; + + va_list ap_copy; + va_copy(ap_copy, ap); + + ScopedClearLastError last_error; + int result = vsnprintfT(stack_buf, gurl_base::size(stack_buf), format, ap_copy); + va_end(ap_copy); + + if (result >= 0 && result < static_cast(gurl_base::size(stack_buf))) { + // It fit. + dst->append(stack_buf, result); + return; + } + + // Repeatedly increase buffer size until it fits. + int mem_length = gurl_base::size(stack_buf); + while (true) { + if (result < 0) { + if (errno != 0 && errno != EOVERFLOW) + return; + // Try doubling the buffer size. + mem_length *= 2; + } else { + // We need exactly "result + 1" characters. + mem_length = result + 1; + } + + if (mem_length > 32 * 1024 * 1024) { + // That should be plenty, don't try anything larger. This protects + // against huge allocations when using vsnprintfT implementations that + // return -1 for reasons other than overflow without setting errno. + GURL_DLOG(WARNING) << "Unable to printf the requested string due to size."; + return; + } + + std::vector mem_buf(mem_length); + + // NOTE: You can only use a va_list once. Since we're in a while loop, we + // need to make a new copy each time so we don't use up the original. + va_copy(ap_copy, ap); + result = vsnprintfT(&mem_buf[0], mem_length, format, ap_copy); + va_end(ap_copy); + + if ((result >= 0) && (result < mem_length)) { + // It fit. + dst->append(&mem_buf[0], result); + return; + } + } +} + +} // namespace + +std::string StringPrintV(const char* format, va_list ap) { + std::string result; + StringAppendV(&result, format, ap); + return result; +} + +const std::string& SStringPrintf(std::string* dst, const char* format, ...) { + va_list ap; + va_start(ap, format); + dst->clear(); + StringAppendV(dst, format, ap); + va_end(ap); + return *dst; +} + +void StringAppendF(std::string* dst, const char* format, ...) { + va_list ap; + va_start(ap, format); + StringAppendV(dst, format, ap); + va_end(ap); +} + +void StringAppendV(std::string* dst, const char* format, va_list ap) { + StringAppendVT(dst, format, ap); +} + +} // namespace base diff --git a/base/strings/stringprintf.h b/base/strings/stringprintf.h index 9fd83ed7..f8358e68 100644 --- a/base/strings/stringprintf.h +++ b/base/strings/stringprintf.h @@ -1,42 +1,42 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#ifndef QUICHE_STRINGS_STRINGPRINTF_H_ -#define QUICHE_STRINGS_STRINGPRINTF_H_ - -#include -#include // va_list - -#include - -#include "googleurl/polyfills/base/base_export.h" -#include "googleurl/base/compiler_specific.h" -#include "googleurl/base/macros.h" -#include "googleurl/build/build_config.h" - -#include "base/strings/stringprintf.h" - -namespace base { - -// Return a C++ string given vprintf-like input. -std::string StringPrintV(const char* format, va_list ap) - PRINTF_FORMAT(1, 0) WARN_UNUSED_RESULT; - -// Store result into a supplied string and return it. -const std::string& SStringPrintf(std::string* dst, - const char* format, - ...) PRINTF_FORMAT(2, 3); - -// Append result to a supplied string. -void StringAppendF(std::string* dst, const char* format, ...) - PRINTF_FORMAT(2, 3); - -// Lower-level routine that takes a va_list and appends to a specified -// string. All other routines are just convenience wrappers around it. -void StringAppendV(std::string* dst, const char* format, va_list ap) - PRINTF_FORMAT(2, 0); - -} // namespace base - -#endif // QUICHE_STRINGS_STRINGPRINTF_H_ +// Copyright 2013 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_STRINGS_STRINGPRINTF_H_ +#define QUICHE_STRINGS_STRINGPRINTF_H_ + +#include +#include // va_list + +#include + +#include "googleurl/polyfills/base/base_export.h" +#include "googleurl/base/compiler_specific.h" +#include "googleurl/base/macros.h" +#include "googleurl/build/build_config.h" + +#include "base/strings/stringprintf.h" + +namespace base { + +// Return a C++ string given vprintf-like input. +std::string StringPrintV(const char* format, va_list ap) + PRINTF_FORMAT(1, 0) WARN_UNUSED_RESULT; + +// Store result into a supplied string and return it. +const std::string& SStringPrintf(std::string* dst, + const char* format, + ...) PRINTF_FORMAT(2, 3); + +// Append result to a supplied string. +void StringAppendF(std::string* dst, const char* format, ...) + PRINTF_FORMAT(2, 3); + +// Lower-level routine that takes a va_list and appends to a specified +// string. All other routines are just convenience wrappers around it. +void StringAppendV(std::string* dst, const char* format, va_list ap) + PRINTF_FORMAT(2, 0); + +} // namespace base + +#endif // QUICHE_STRINGS_STRINGPRINTF_H_ diff --git a/gquiche/common/platform/api/quiche_bug_tracker.h b/gquiche/common/platform/api/quiche_bug_tracker.h index c5c42260..8598f3dc 100644 --- a/gquiche/common/platform/api/quiche_bug_tracker.h +++ b/gquiche/common/platform/api/quiche_bug_tracker.h @@ -5,7 +5,7 @@ #ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_BUG_TRACKER_H_ #define QUICHE_COMMON_PLATFORM_API_QUICHE_BUG_TRACKER_H_ -#include "platform/quiche_platform_impl/quiche_bug_tracker_impl.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h" #define QUICHE_BUG QUICHE_BUG_IMPL #define QUICHE_BUG_IF QUICHE_BUG_IF_IMPL diff --git a/gquiche/common/platform/api/quiche_export.h b/gquiche/common/platform/api/quiche_export.h index ef580779..690901bd 100644 --- a/gquiche/common/platform/api/quiche_export.h +++ b/gquiche/common/platform/api/quiche_export.h @@ -5,7 +5,7 @@ #ifndef THIRD_PARTY_QUICHE_PLATFORM_API_QUICHE_EXPORT_H_ #define THIRD_PARTY_QUICHE_PLATFORM_API_QUICHE_EXPORT_H_ -#include "platform/quiche_platform_impl/quiche_export_impl.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h" // QUICHE_EXPORT is not meant to be used. #define QUICHE_EXPORT QUICHE_EXPORT_IMPL diff --git a/gquiche/common/platform/api/quiche_file_utils.cc b/gquiche/common/platform/api/quiche_file_utils.cc new file mode 100644 index 00000000..63bfb2f3 --- /dev/null +++ b/gquiche/common/platform/api/quiche_file_utils.cc @@ -0,0 +1,51 @@ +#include "gquiche/common/platform/api/quiche_file_utils.h" + +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h" + +namespace quiche { + +std::string JoinPath(absl::string_view a, absl::string_view b) { + return JoinPathImpl(a, b); +} + +absl::optional ReadFileContents(absl::string_view file) { + return ReadFileContentsImpl(file); +} + +bool EnumerateDirectory(absl::string_view path, + std::vector& directories, + std::vector& files) { + return EnumerateDirectoryImpl(path, directories, files); +} + +bool EnumerateDirectoryRecursivelyInner(absl::string_view path, + int recursion_limit, + std::vector& files) { + if (recursion_limit < 0) { + return false; + } + + std::vector local_files; + std::vector directories; + if (!EnumerateDirectory(path, directories, local_files)) { + return false; + } + for (const std::string& directory : directories) { + if (!EnumerateDirectoryRecursivelyInner(JoinPath(path, directory), + recursion_limit - 1, files)) { + return false; + } + } + for (const std::string& file : local_files) { + files.push_back(JoinPath(path, file)); + } + return true; +} + +bool EnumerateDirectoryRecursively(absl::string_view path, + std::vector& files) { + constexpr int kRecursionLimit = 20; + return EnumerateDirectoryRecursivelyInner(path, kRecursionLimit, files); +} + +} // namespace quiche diff --git a/gquiche/common/platform/api/quiche_file_utils.h b/gquiche/common/platform/api/quiche_file_utils.h new file mode 100644 index 00000000..47723d19 --- /dev/null +++ b/gquiche/common/platform/api/quiche_file_utils.h @@ -0,0 +1,40 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This header contains basic filesystem functions for use in unit tests and CLI +// tools. Note that those are not 100% suitable for production use, as in, they +// might be prone to race conditions and not always handle non-ASCII filenames +// correctly. +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_FILE_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_FILE_UTILS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace quiche { + +// Join two paths in a platform-specific way. Returns |a| if |b| is empty, and +// vice versa. +std::string JoinPath(absl::string_view a, absl::string_view b); + +// Reads the entire file into the memory. +absl::optional ReadFileContents(absl::string_view file); + +// Lists all files and directories in the directory specified by |path|. Returns +// true on success, false on failure. +bool EnumerateDirectory(absl::string_view path, + std::vector& directories, + std::vector& files); + +// Recursively enumerates all of the files in the directory and all of the +// internal subdirectories. Has a fairly small recursion limit. +bool EnumerateDirectoryRecursively(absl::string_view path, + std::vector& files); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_FILE_UTILS_H_ diff --git a/gquiche/common/platform/api/quiche_file_utils_test.cc b/gquiche/common/platform/api/quiche_file_utils_test.cc new file mode 100644 index 00000000..146f4430 --- /dev/null +++ b/gquiche/common/platform/api/quiche_file_utils_test.cc @@ -0,0 +1,86 @@ +#include "gquiche/common/platform/api/quiche_file_utils.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { +namespace { + +using testing::UnorderedElementsAre; +using testing::UnorderedElementsAreArray; + +TEST(QuicheFileUtilsTest, ReadFileContents) { + std::string path = absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/testfile"); + absl::optional contents = ReadFileContents(path); + ASSERT_TRUE(contents.has_value()); + EXPECT_EQ(*contents, "This is a test file."); +} + +TEST(QuicheFileUtilsTest, ReadFileContentsFileNotFound) { + std::string path = + absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/file-that-does-not-exist"); + absl::optional contents = ReadFileContents(path); + EXPECT_FALSE(contents.has_value()); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectory) { + std::string path = + absl::StrCat(QuicheGetCommonSourcePath(), "/platform/api/testdir"); + std::vector dirs; + std::vector files; + bool success = EnumerateDirectory(path, dirs, files); + EXPECT_TRUE(success); + EXPECT_THAT(files, UnorderedElementsAre("testfile", "README.md")); + EXPECT_THAT(dirs, UnorderedElementsAre("a")); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectoryNoSuchDirectory) { + std::string path = absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/no-such-directory"); + std::vector dirs; + std::vector files; + bool success = EnumerateDirectory(path, dirs, files); + EXPECT_FALSE(success); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectoryNotADirectory) { + std::string path = absl::StrCat(QuicheGetCommonSourcePath(), + "/platform/api/testdir/testfile"); + std::vector dirs; + std::vector files; + bool success = EnumerateDirectory(path, dirs, files); + EXPECT_FALSE(success); +} + +TEST(QuicheFileUtilsTest, EnumerateDirectoryRecursively) { + std::vector expected_paths = {"a/b/c/d/e", "a/subdir/testfile", + "a/z", "testfile", "README.md"}; + + std::string root_path = + absl::StrCat(QuicheGetCommonSourcePath(), "/platform/api/testdir"); + for (std::string& path : expected_paths) { + // For Windows, use Windows path separators. + if (JoinPath("a", "b") == "a\\b") { + absl::c_replace(path, '/', '\\'); + } + + path = JoinPath(root_path, path); + } + + std::vector files; + bool success = EnumerateDirectoryRecursively(root_path, files); + EXPECT_TRUE(success); + EXPECT_THAT(files, UnorderedElementsAreArray(expected_paths)); +} + +} // namespace +} // namespace test +} // namespace quiche diff --git a/gquiche/common/platform/api/quiche_flag_utils.h b/gquiche/common/platform/api/quiche_flag_utils.h index 4678fe0a..71dc607f 100644 --- a/gquiche/common/platform/api/quiche_flag_utils.h +++ b/gquiche/common/platform/api/quiche_flag_utils.h @@ -13,7 +13,7 @@ #define QUICHE_RESTART_FLAG_COUNT QUICHE_RESTART_FLAG_COUNT_IMPL #define QUICHE_RESTART_FLAG_COUNT_N QUICHE_RESTART_FLAG_COUNT_N_IMPL -#define QUICHE_CODE_COUNT QUICHE_CODE_COUNT_IMPL +#define QUICHE_CODE_COUNT(x) #define QUICHE_CODE_COUNT_N QUICHE_CODE_COUNT_N_IMPL #endif // QUICHE_COMMON_PLATFORM_API_QUICHE_FLAG_UTILS_H_ diff --git a/gquiche/common/platform/api/quiche_prefetch.h b/gquiche/common/platform/api/quiche_prefetch.h new file mode 100644 index 00000000..d6174a6e --- /dev/null +++ b/gquiche/common/platform/api/quiche_prefetch.h @@ -0,0 +1,39 @@ +// Copyright 2017 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_PREFETCH_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_PREFETCH_H_ + +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h" + +namespace quiche { + +// Move data into the cache before it is read, or "prefetch" it. +// +// The value of `addr` is the address of the memory to prefetch. If +// the target and compiler support it, data prefetch instructions are +// generated. If the prefetch is done some time before the memory is +// read, it may be in the cache by the time the read occurs. +// +// The function names specify the temporal locality heuristic applied, +// using the names of Intel prefetch instructions: +// +// T0 - high degree of temporal locality; data should be left in as +// many levels of the cache possible +// T1 - moderate degree of temporal locality +// T2 - low degree of temporal locality +// Nta - no temporal locality, data need not be left in the cache +// after the read +// +// Incorrect or gratuitous use of these functions can degrade +// performance, so use them only when representative benchmarks show +// an improvement. + +inline void QuichePrefetchT0(const void* addr) { + return QuichePrefetchT0Impl(addr); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_PREFETCH_H_ diff --git a/gquiche/common/platform/api/quiche_test.h b/gquiche/common/platform/api/quiche_test.h index 81b9aad9..0597e686 100644 --- a/gquiche/common/platform/api/quiche_test.h +++ b/gquiche/common/platform/api/quiche_test.h @@ -5,11 +5,26 @@ #ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_H_ #define QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_H_ -#include "platform/quiche_test_impl.h" +#include "platform/quiche_platform_impl/quiche_test_impl.h" using QuicheTest = quiche::test::QuicheTest; template using QuicheTestWithParam = quiche::test::QuicheTestWithParamImpl; +namespace quiche { +namespace test { + +// Returns the path to quiche/common directory where the test data could be +// located. +inline std::string QuicheGetCommonSourcePath() { + return QuicheGetCommonSourcePathImpl(); +} + +} // namespace test +} // namespace quiche + +#define EXPECT_QUICHE_DEBUG_DEATH(condition, message) \ + EXPECT_QUICHE_DEBUG_DEATH_IMPL(condition, message) + #endif // QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_H_ diff --git a/gquiche/common/platform/api/quiche_test_helpers.h b/gquiche/common/platform/api/quiche_test_helpers.h index 9af22edd..097f305c 100644 --- a/gquiche/common/platform/api/quiche_test_helpers.h +++ b/gquiche/common/platform/api/quiche_test_helpers.h @@ -1,7 +1,7 @@ #ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_HELPERS_H_ #define QUICHE_COMMON_PLATFORM_API_QUICHE_TEST_HELPERS_H_ -#include "platform/quiche_test_helpers_impl.h" +#include "platform/quiche_platform_impl/quiche_test_helpers_impl.h" #define EXPECT_QUICHE_BUG EXPECT_QUICHE_BUG_IMPL diff --git a/gquiche/common/platform/api/quiche_thread_local.h b/gquiche/common/platform/api/quiche_thread_local.h new file mode 100644 index 00000000..faab5154 --- /dev/null +++ b/gquiche/common/platform/api/quiche_thread_local.h @@ -0,0 +1,27 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_THREAD_LOCAL_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_THREAD_LOCAL_H_ + +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_thread_local_impl.h" + +// Define a thread local |type*| with |name|. Conceptually, this is a +// +// static thread_local type* name = nullptr; +// +// It is wrapped in a macro because the thread_local keyword is banned from +// Chromium. +#define DEFINE_QUICHE_THREAD_LOCAL_POINTER(name, type) \ + DEFINE_QUICHE_THREAD_LOCAL_POINTER_IMPL(name, type) + +// Get the value of |name| for the current thread. +#define GET_QUICHE_THREAD_LOCAL_POINTER(name) \ + GET_QUICHE_THREAD_LOCAL_POINTER_IMPL(name) + +// Set the |value| of |name| for the current thread. +#define SET_QUICHE_THREAD_LOCAL_POINTER(name, value) \ + SET_QUICHE_THREAD_LOCAL_POINTER_IMPL(name, value) + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_THREAD_LOCAL_H_ diff --git a/gquiche/common/platform/api/quiche_time_utils.h b/gquiche/common/platform/api/quiche_time_utils.h index 2917971d..a029cb24 100644 --- a/gquiche/common/platform/api/quiche_time_utils.h +++ b/gquiche/common/platform/api/quiche_time_utils.h @@ -7,7 +7,7 @@ #include -#include "platform/quiche_platform_impl/quiche_time_utils_impl.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_time_utils_impl.h" namespace quiche { diff --git a/gquiche/common/platform/api/quiche_url_utils.h b/gquiche/common/platform/api/quiche_url_utils.h new file mode 100644 index 00000000..98ff59b5 --- /dev/null +++ b/gquiche/common/platform/api/quiche_url_utils.h @@ -0,0 +1,38 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_API_QUICHE_URL_UTILS_H_ +#define QUICHE_COMMON_PLATFORM_API_QUICHE_URL_UTILS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h" + +namespace quiche { + +// Produces concrete URLs in |target| from templated ones in |uri_template|. +// Parameters are URL-encoded. Collects the names of any expanded variables in +// |vars_found|. Returns true if the template was parseable, false if it was +// malformed. +inline bool ExpandURITemplate( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + std::string* target, + absl::flat_hash_set* vars_found = nullptr) { + return ExpandURITemplateImpl(uri_template, parameters, target, vars_found); +} + +// Decodes a URL-encoded string and converts it to ASCII. If the decoded input +// contains non-ASCII characters, decoding fails and absl::nullopt is returned. +inline absl::optional AsciiUrlDecode(absl::string_view input) { + return AsciiUrlDecodeImpl(input); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_API_QUICHE_URL_UTILS_H_ diff --git a/gquiche/common/platform/api/quiche_url_utils_test.cc b/gquiche/common/platform/api/quiche_url_utils_test.cc new file mode 100644 index 00000000..0ef16f76 --- /dev/null +++ b/gquiche/common/platform/api/quiche_url_utils_test.cc @@ -0,0 +1,80 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/platform/api/quiche_url_utils.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/types/optional.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace { + +void ValidateExpansion( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + const std::string& expected_expansion, + const absl::flat_hash_set& expected_vars_found) { + absl::flat_hash_set vars_found; + std::string target; + ASSERT_TRUE( + ExpandURITemplate(uri_template, parameters, &target, &vars_found)); + EXPECT_EQ(expected_expansion, target); + EXPECT_EQ(vars_found, expected_vars_found); +} + +TEST(QuicheUrlUtilsTest, Basic) { + ValidateExpansion("/{foo}/{bar}/", {{"foo", "123"}, {"bar", "456"}}, + "/123/456/", {"foo", "bar"}); +} + +TEST(QuicheUrlUtilsTest, ExtraParameter) { + ValidateExpansion("/{foo}/{bar}/{baz}/", {{"foo", "123"}, {"bar", "456"}}, + "/123/456//", {"foo", "bar"}); +} + +TEST(QuicheUrlUtilsTest, MissingParameter) { + ValidateExpansion("/{foo}/{baz}/", {{"foo", "123"}, {"bar", "456"}}, "/123//", + {"foo"}); +} + +TEST(QuicheUrlUtilsTest, RepeatedParameter) { + ValidateExpansion("/{foo}/{bar}/{foo}/", {{"foo", "123"}, {"bar", "456"}}, + "/123/456/123/", {"foo", "bar"}); +} + +TEST(QuicheUrlUtilsTest, URLEncoding) { + ValidateExpansion("/{foo}/{bar}/", {{"foo", "123"}, {"bar", ":"}}, + "/123/%3A/", {"foo", "bar"}); +} + +void ValidateUrlDecode(const std::string& input, + const absl::optional& expected_output) { + absl::optional decode_result = AsciiUrlDecode(input); + if (!expected_output.has_value()) { + EXPECT_FALSE(decode_result.has_value()); + return; + } + ASSERT_TRUE(decode_result.has_value()); + EXPECT_EQ(decode_result.value(), expected_output); +} + +TEST(QuicheUrlUtilsTest, DecodeNoChange) { + ValidateUrlDecode("foobar", "foobar"); +} + +TEST(QuicheUrlUtilsTest, DecodeReplace) { + ValidateUrlDecode("%7Bfoobar%7D", "{foobar}"); +} + +TEST(QuicheUrlUtilsTest, DecodeFail) { + ValidateUrlDecode("%FF", absl::nullopt); +} + +} // namespace +} // namespace quiche diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h new file mode 100644 index 00000000..b384b294 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_bug_tracker_impl.h @@ -0,0 +1,15 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_BUG_TRACKER_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_BUG_TRACKER_IMPL_H_ + +#include "gquiche/common/platform/api/quiche_logging.h" + +#define QUICHE_BUG_IMPL(b) QUICHE_LOG(DFATAL) +#define QUICHE_BUG_IF_IMPL(b, condition) QUICHE_LOG_IF(DFATAL, condition) +#define QUICHE_PEER_BUG_IMPL(b) QUICHE_LOG(DFATAL) +#define QUICHE_PEER_BUG_IF_IMPL(b, condition) QUICHE_LOG_IF(DFATAL, condition) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_BUG_TRACKER_IMPL_H_ diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h index 2e18dd7a..3c78ed51 100644 --- a/gquiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h @@ -5,16 +5,17 @@ #ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CONTAINERS_IMPL_H_ #define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_CONTAINERS_IMPL_H_ -#include +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wall" + +#include "absl/container/btree_set.h" + +#pragma GCC diagnostic pop namespace quiche { -// Represents a double-ended queue which may be backed by a list or a flat -// circular buffer. -// -// DOES NOT GUARANTEE POINTER OR ITERATOR STABILITY! -template -using QuicheDequeImpl = std::deque; +template +using QuicheSmallOrderedSetImpl = absl::btree_set; } // namespace quiche diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc b/gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc new file mode 100644 index 00000000..aa30ce12 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.cc @@ -0,0 +1,182 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h" + +#if defined(_WIN32) +#include +#else +#include +#include +#include +#include +#endif // defined(_WIN32) + +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" +#include "absl/types/optional.h" + +namespace quiche { + +#if defined(_WIN32) +std::string JoinPathImpl(absl::string_view a, absl::string_view b) { + if (a.empty()) { + return std::string(b); + } + if (b.empty()) { + return std::string(a); + } + // Win32 actually provides two different APIs for combining paths; one of them + // has issues that could potentially lead to buffer overflow, and another is + // not supported in Windows 7, which is why we're doing it manually. + a = absl::StripSuffix(a, "/"); + a = absl::StripSuffix(a, "\\"); + return absl::StrCat(a, "\\", b); +} +#else +std::string JoinPathImpl(absl::string_view a, absl::string_view b) { + if (a.empty()) { + return std::string(b); + } + if (b.empty()) { + return std::string(a); + } + return absl::StrCat(absl::StripSuffix(a, "/"), "/", b); +} +#endif // defined(_WIN32) + +absl::optional ReadFileContentsImpl(absl::string_view file) { + std::ifstream input_file(std::string{file}, std::ios::binary); + if (!input_file || !input_file.is_open()) { + return absl::nullopt; + } + + input_file.seekg(0, std::ios_base::end); + auto file_size = input_file.tellg(); + if (!input_file) { + return absl::nullopt; + } + input_file.seekg(0, std::ios_base::beg); + + std::string output; + output.resize(file_size); + input_file.read(&output[0], file_size); + if (!input_file) { + return absl::nullopt; + } + + return output; +} + +#if defined(_WIN32) + +class ScopedDir { + public: + ScopedDir(HANDLE dir) : dir_(dir) {} + ~ScopedDir() { + if (dir_ != INVALID_HANDLE_VALUE) { + // The API documentation explicitly says that CloseHandle() should not be + // used on directory search handles. + FindClose(dir_); + dir_ = INVALID_HANDLE_VALUE; + } + } + + HANDLE get() { return dir_; } + + private: + HANDLE dir_; +}; + +bool EnumerateDirectoryImpl(absl::string_view path, + std::vector& directories, + std::vector& files) { + std::string path_owned(path); + + // Explicitly check that the directory we are trying to search is in fact a + // directory. + DWORD attributes = GetFileAttributesA(path_owned.c_str()); + if (attributes == INVALID_FILE_ATTRIBUTES) { + return false; + } + if ((attributes & FILE_ATTRIBUTE_DIRECTORY) == 0) { + return false; + } + + std::string search_path = JoinPathImpl(path, "*"); + WIN32_FIND_DATAA file_data; + ScopedDir dir(FindFirstFileA(search_path.c_str(), &file_data)); + if (dir.get() == INVALID_HANDLE_VALUE) { + return GetLastError() == ERROR_FILE_NOT_FOUND; + } + do { + std::string filename(file_data.cFileName); + if (filename == "." || filename == "..") { + continue; + } + if ((file_data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0) { + directories.push_back(std::move(filename)); + } else { + files.push_back(std::move(filename)); + } + } while (FindNextFileA(dir.get(), &file_data)); + return GetLastError() == ERROR_NO_MORE_FILES; +} + +#else // defined(_WIN32) + +class ScopedDir { + public: + ScopedDir(DIR* dir) : dir_(dir) {} + ~ScopedDir() { + if (dir_ != nullptr) { + closedir(dir_); + dir_ = nullptr; + } + } + + DIR* get() { return dir_; } + + private: + DIR* dir_; +}; + +bool EnumerateDirectoryImpl(absl::string_view path, + std::vector& directories, + std::vector& files) { + std::string path_owned(path); + ScopedDir dir(opendir(path_owned.c_str())); + if (dir.get() == nullptr) { + return false; + } + + dirent* entry; + while ((entry = readdir(dir.get()))) { + const std::string filename(entry->d_name); + if (filename == "." || filename == "..") { + continue; + } + + const std::string entry_path = JoinPathImpl(path, filename); + struct stat stat_entry; + if (stat(entry_path.c_str(), &stat_entry) != 0) { + return false; + } + if (S_ISREG(stat_entry.st_mode)) { + files.push_back(std::move(filename)); + } else if (S_ISDIR(stat_entry.st_mode)) { + directories.push_back(std::move(filename)); + } + } + return true; +} + +#endif // defined(_WIN32) + +} // namespace quiche diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h new file mode 100644 index 00000000..ad5ff1a9 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_file_utils_impl.h @@ -0,0 +1,26 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FILE_UTILS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FILE_UTILS_IMPL_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace quiche { + +std::string JoinPathImpl(absl::string_view a, absl::string_view b); + +absl::optional ReadFileContentsImpl(absl::string_view file); + +bool EnumerateDirectoryImpl(absl::string_view path, + std::vector& directories, + std::vector& files); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_FILE_UTILS_IMPL_H_ diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h new file mode 100644 index 00000000..89454817 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_prefetch_impl.h @@ -0,0 +1,28 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_PREFETCH_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_PREFETCH_IMPL_H_ + +#if defined(_MSC_VER) +#include +#endif + +namespace quiche { + +inline void QuichePrefetchT0Impl(const void* addr) { +#if !defined(DISABLE_BUILTIN_PREFETCH) +#if defined(__GNUC__) || (defined(_M_ARM64) && defined(__clang__)) + __builtin_prefetch(addr, 0, 3); +#elif defined(_MSC_VER) + _mm_prefetch(reinterpret_cast(addr), _MM_HINT_T0); +#else + (void*)addr; +#endif +#endif // !defined(DISABLE_BUILTIN_PREFETCH) +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_PREFETCH_IMPL_H_ diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_sleep_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_sleep_impl.h new file mode 100644 index 00000000..6879d1f4 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_sleep_impl.h @@ -0,0 +1,22 @@ +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SLEEP_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SLEEP_IMPL_H_ + +#include "gquiche/quic/core/quic_time.h" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wall" + +#include "absl/time/clock.h" +#include "absl/time/time.h" + +#pragma GCC diagnostic pop + +namespace quic { + +inline void QuicSleepImpl(QuicTime::Delta duration) { + absl::SleepFor(absl::Microseconds(duration.ToMicroseconds())); +} + +} // namespace quic + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_SLEEP_IMPL_H_ diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_thread_local_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_thread_local_impl.h new file mode 100644 index 00000000..5ebea4c5 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_thread_local_impl.h @@ -0,0 +1,24 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_THREAD_LOCAL_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_THREAD_LOCAL_IMPL_H_ + +#define DEFINE_QUICHE_THREAD_LOCAL_POINTER_IMPL(name, type) \ + struct QuicheThreadLocalPointer_##name { \ + static type** Instance() { \ + static thread_local type* instance = nullptr; \ + return &instance; \ + } \ + static type* Get() { return *Instance(); } \ + static void Set(type* ptr) { *Instance() = ptr; } \ + } + +#define GET_QUICHE_THREAD_LOCAL_POINTER_IMPL(name) \ + QuicheThreadLocalPointer_##name::Get() + +#define SET_QUICHE_THREAD_LOCAL_POINTER_IMPL(name, value) \ + QuicheThreadLocalPointer_##name::Set(value) + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_THREAD_LOCAL_IMPL_H_ diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc b/gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc new file mode 100644 index 00000000..00bce550 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.cc @@ -0,0 +1,78 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_platform_impl/quiche_url_utils_impl.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "url/url_canon.h" +#include "url/url_util.h" + +namespace quiche { + +bool ExpandURITemplateImpl( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + std::string* target, absl::flat_hash_set* vars_found) { + absl::flat_hash_set found; + std::string result = uri_template; + for (const auto& pair : parameters) { + const std::string& name = pair.first; + const std::string& value = pair.second; + std::string name_input = absl::StrCat("{", name, "}"); + url::RawCanonOutputT canon_value; + url::EncodeURIComponent(value.c_str(), value.length(), &canon_value); + std::string encoded_value(canon_value.data(), canon_value.length()); + int num_replaced = + absl::StrReplaceAll({{name_input, encoded_value}}, &result); + if (num_replaced > 0) { + found.insert(name); + } + } + // Remove any remaining variables that were not present in |parameters|. + while (true) { + size_t start = result.find('{'); + if (start == std::string::npos) { + break; + } + size_t end = result.find('}'); + if (end == std::string::npos || end <= start) { + return false; + } + result.erase(start, (end - start) + 1); + } + if (vars_found != nullptr) { + *vars_found = found; + } + *target = result; + return true; +} + +absl::optional AsciiUrlDecodeImpl(absl::string_view input) { + std::string input_encoded = std::string(input); + url::RawCanonOutputW<1024> canon_output; + url::DecodeURLEscapeSequences(input_encoded.c_str(), input_encoded.length(), + &canon_output); + std::string output; + output.reserve(canon_output.length()); + for (int i = 0; i < canon_output.length(); i++) { + const uint16_t c = reinterpret_cast(canon_output.data())[i]; + if (c > std::numeric_limits::max()) { + return absl::nullopt; + } + output += static_cast(c); + } + return output; +} + +} // namespace quiche diff --git a/gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h b/gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h new file mode 100644 index 00000000..3ceadb67 --- /dev/null +++ b/gquiche/common/platform/default/quiche_platform_impl/quiche_url_utils_impl.h @@ -0,0 +1,35 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_URL_UTILS_IMPL_H_ +#define QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_URL_UTILS_IMPL_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Produces concrete URLs in |target| from templated ones in |uri_template|. +// Parameters are URL-encoded. Collects the names of any expanded variables in +// |vars_found|. Supports level 1 templates as specified in RFC 6570. Returns +// true if the template was parseable, false if it was malformed. +QUICHE_EXPORT_PRIVATE bool ExpandURITemplateImpl( + const std::string& uri_template, + const absl::flat_hash_map& parameters, + std::string* target, + absl::flat_hash_set* vars_found = nullptr); + +// Decodes a URL-encoded string and converts it to ASCII. If the decoded input +// contains non-ASCII characters, decoding fails and absl::nullopt is returned. +QUICHE_EXPORT_PRIVATE absl::optional AsciiUrlDecodeImpl( + absl::string_view input); + +} // namespace quiche + +#endif // QUICHE_COMMON_PLATFORM_DEFAULT_QUICHE_PLATFORM_IMPL_QUICHE_URL_UTILS_IMPL_H_ diff --git a/gquiche/common/print_elements.h b/gquiche/common/print_elements.h new file mode 100644 index 00000000..b60688ec --- /dev/null +++ b/gquiche/common/print_elements.h @@ -0,0 +1,37 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_PRINT_ELEMENTS_H_ +#define QUICHE_COMMON_PRINT_ELEMENTS_H_ + +#include +#include +#include + +#include "gquiche/common/platform/api/quiche_export.h" + +namespace quiche { + +// Print elements of any iterable container that has cbegin() and cend() methods +// and the elements have operator<<(ostream) override. +template +QUICHE_EXPORT_PRIVATE inline std::string PrintElements(const T& container) { + std::stringstream debug_string; + debug_string << "{"; + auto it = container.cbegin(); + if (it != container.cend()) { + debug_string << *it; + ++it; + while (it != container.cend()) { + debug_string << ", " << *it; + ++it; + } + } + debug_string << "}"; + return debug_string.str(); +} + +} // namespace quiche + +#endif // QUICHE_COMMON_PRINT_ELEMENTS_H_ diff --git a/gquiche/common/print_elements_test.cc b/gquiche/common/print_elements_test.cc new file mode 100644 index 00000000..1dd05dd2 --- /dev/null +++ b/gquiche/common/print_elements_test.cc @@ -0,0 +1,61 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/print_elements.h" + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "quic/core/quic_error_codes.h" +#include "gquiche/common/platform/api/quiche_test.h" + +using quic::QuicIetfTransportErrorCodes; + +namespace quiche { +namespace test { +namespace { + +TEST(PrintElementsTest, Empty) { + std::vector empty{}; + EXPECT_EQ("{}", PrintElements(empty)); +} + +TEST(PrintElementsTest, StdContainers) { + std::vector one{"foo"}; + EXPECT_EQ("{foo}", PrintElements(one)); + + std::list two{"foo", "bar"}; + EXPECT_EQ("{foo, bar}", PrintElements(two)); + + std::deque three{"foo", "bar", "baz"}; + EXPECT_EQ("{foo, bar, baz}", PrintElements(three)); +} + +// QuicIetfTransportErrorCodes has a custom operator<<() override. +TEST(PrintElementsTest, CustomPrinter) { + std::vector empty{}; + EXPECT_EQ("{}", PrintElements(empty)); + + std::list one{ + QuicIetfTransportErrorCodes::NO_IETF_QUIC_ERROR}; + EXPECT_EQ("{NO_IETF_QUIC_ERROR}", PrintElements(one)); + + std::vector two{ + QuicIetfTransportErrorCodes::FLOW_CONTROL_ERROR, + QuicIetfTransportErrorCodes::STREAM_LIMIT_ERROR}; + EXPECT_EQ("{FLOW_CONTROL_ERROR, STREAM_LIMIT_ERROR}", PrintElements(two)); + + std::list three{ + QuicIetfTransportErrorCodes::CONNECTION_ID_LIMIT_ERROR, + QuicIetfTransportErrorCodes::PROTOCOL_VIOLATION, + QuicIetfTransportErrorCodes::INVALID_TOKEN}; + EXPECT_EQ("{CONNECTION_ID_LIMIT_ERROR, PROTOCOL_VIOLATION, INVALID_TOKEN}", + PrintElements(three)); +} + +} // anonymous namespace +} // namespace test +} // namespace quiche diff --git a/gquiche/common/quiche_circular_deque.h b/gquiche/common/quiche_circular_deque.h new file mode 100644 index 00000000..c50b17f4 --- /dev/null +++ b/gquiche/common/quiche_circular_deque.h @@ -0,0 +1,759 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_CIRCULAR_DEQUE_H_ +#define QUICHE_COMMON_QUICHE_CIRCULAR_DEQUE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +// QuicheCircularDeque is a STL-style container that is similar to std::deque in +// API and std::vector in capacity management. The goal is to optimize a common +// QUIC use case where we keep adding new elements to the end and removing old +// elements from the beginning, under such scenarios, if the container's size() +// remain relatively stable, QuicheCircularDeque requires little to no memory +// allocations or deallocations. +// +// The implementation, as the name suggests, uses a flat circular buffer to hold +// all elements. At any point in time, either +// a) All elements are placed in a contiguous portion of this buffer, like a +// c-array, or +// b) Elements are phycially divided into two parts: the first part occupies the +// end of the buffer and the second part occupies the beginning of the +// buffer. +// +// Currently, elements can only be pushed or poped from either ends, it can't be +// inserted or erased in the middle. +// +// TODO(wub): Make memory grow/shrink strategies customizable. +template > +class QUICHE_NO_EXPORT QuicheCircularDeque { + using AllocatorTraits = std::allocator_traits; + + // Pointee is either T or const T. + template + class QUICHE_NO_EXPORT basic_iterator { + using size_type = typename AllocatorTraits::size_type; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = typename AllocatorTraits::value_type; + using difference_type = typename AllocatorTraits::difference_type; + using pointer = Pointee*; + using reference = Pointee&; + + basic_iterator() = default; + + // A copy constructor if Pointee is T. + // A conversion from iterator to const_iterator if Pointee is const T. + basic_iterator( + const basic_iterator& it) // NOLINT(runtime/explicit) + : deque_(it.deque_), index_(it.index_) {} + + // A copy assignment if Pointee is T. + // A assignment from iterator to const_iterator if Pointee is const T. + basic_iterator& operator=(const basic_iterator& it) { + if (this != &it) { + deque_ = it.deque_; + index_ = it.index_; + } + return *this; + } + + reference operator*() const { return *deque_->index_to_address(index_); } + pointer operator->() const { return deque_->index_to_address(index_); } + reference operator[](difference_type i) { return *(*this + i); } + + basic_iterator& operator++() { + Increment(); + return *this; + } + + basic_iterator operator++(int) { + basic_iterator result = *this; + Increment(); + return result; + } + + basic_iterator operator--() { + Decrement(); + return *this; + } + + basic_iterator operator--(int) { + basic_iterator result = *this; + Decrement(); + return result; + } + + friend basic_iterator operator+(const basic_iterator& it, + difference_type delta) { + basic_iterator result = it; + result.IncrementBy(delta); + return result; + } + + basic_iterator& operator+=(difference_type delta) { + IncrementBy(delta); + return *this; + } + + friend basic_iterator operator-(const basic_iterator& it, + difference_type delta) { + basic_iterator result = it; + result.IncrementBy(-delta); + return result; + } + + basic_iterator& operator-=(difference_type delta) { + IncrementBy(-delta); + return *this; + } + + friend difference_type operator-(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.ExternalPosition() - rhs.ExternalPosition(); + } + + friend bool operator==(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.index_ == rhs.index_; + } + + friend bool operator!=(const basic_iterator& lhs, + const basic_iterator& rhs) { + return !(lhs == rhs); + } + + friend bool operator<(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.ExternalPosition() < rhs.ExternalPosition(); + } + + friend bool operator<=(const basic_iterator& lhs, + const basic_iterator& rhs) { + return !(lhs > rhs); + } + + friend bool operator>(const basic_iterator& lhs, + const basic_iterator& rhs) { + return lhs.ExternalPosition() > rhs.ExternalPosition(); + } + + friend bool operator>=(const basic_iterator& lhs, + const basic_iterator& rhs) { + return !(lhs < rhs); + } + + private: + basic_iterator(const QuicheCircularDeque* deque, size_type index) + : deque_(deque), index_(index) {} + + void Increment() { + QUICHE_DCHECK_LE(ExternalPosition() + 1, deque_->size()); + index_ = deque_->index_next(index_); + } + + void Decrement() { + QUICHE_DCHECK_GE(ExternalPosition(), 1u); + index_ = deque_->index_prev(index_); + } + + void IncrementBy(difference_type delta) { + if (delta >= 0) { + // After increment we are before or at end(). + QUICHE_DCHECK_LE(static_cast(ExternalPosition() + delta), + deque_->size()); + } else { + // After decrement we are after or at begin(). + QUICHE_DCHECK_GE(ExternalPosition(), static_cast(-delta)); + } + index_ = deque_->index_increment_by(index_, delta); + } + + size_type ExternalPosition() const { + if (index_ >= deque_->begin_) { + return index_ - deque_->begin_; + } + return index_ + deque_->data_capacity() - deque_->begin_; + } + + friend class QuicheCircularDeque; + const QuicheCircularDeque* deque_ = nullptr; + size_type index_ = 0; + }; + + public: + using allocator_type = typename AllocatorTraits::allocator_type; + using value_type = typename AllocatorTraits::value_type; + using size_type = typename AllocatorTraits::size_type; + using difference_type = typename AllocatorTraits::difference_type; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = typename AllocatorTraits::pointer; + using const_pointer = typename AllocatorTraits::const_pointer; + using iterator = basic_iterator; + using const_iterator = basic_iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + QuicheCircularDeque() : QuicheCircularDeque(allocator_type()) {} + explicit QuicheCircularDeque(const allocator_type& alloc) + : allocator_and_data_(alloc) {} + + QuicheCircularDeque(size_type count, + const T& value, + const Allocator& alloc = allocator_type()) + : allocator_and_data_(alloc) { + resize(count, value); + } + + explicit QuicheCircularDeque(size_type count, + const Allocator& alloc = allocator_type()) + : allocator_and_data_(alloc) { + resize(count); + } + + template < + class InputIt, + typename = std::enable_if_t::iterator_category>::value>> + QuicheCircularDeque(InputIt first, + InputIt last, + const Allocator& alloc = allocator_type()) + : allocator_and_data_(alloc) { + AssignRange(first, last); + } + + QuicheCircularDeque(const QuicheCircularDeque& other) + : QuicheCircularDeque( + other, + AllocatorTraits::select_on_container_copy_construction( + other.allocator_and_data_.allocator())) {} + + QuicheCircularDeque(const QuicheCircularDeque& other, + const allocator_type& alloc) + : allocator_and_data_(alloc) { + assign(other.begin(), other.end()); + } + + QuicheCircularDeque(QuicheCircularDeque&& other) + : begin_(other.begin_), + end_(other.end_), + allocator_and_data_(std::move(other.allocator_and_data_)) { + other.begin_ = other.end_ = 0; + other.allocator_and_data_.data = nullptr; + other.allocator_and_data_.data_capacity = 0; + } + + QuicheCircularDeque(QuicheCircularDeque&& other, const allocator_type& alloc) + : allocator_and_data_(alloc) { + MoveRetainAllocator(std::move(other)); + } + + QuicheCircularDeque(std::initializer_list init, + const allocator_type& alloc = allocator_type()) + : QuicheCircularDeque(init.begin(), init.end(), alloc) {} + + QuicheCircularDeque& operator=(const QuicheCircularDeque& other) { + if (this == &other) { + return *this; + } + if (AllocatorTraits::propagate_on_container_copy_assignment::value && + (allocator_and_data_.allocator() != + other.allocator_and_data_.allocator())) { + // Destroy all current elements and blocks with the current allocator, + // before switching this to use the allocator propagated from "other". + DestroyAndDeallocateAll(); + begin_ = end_ = 0; + allocator_and_data_ = + AllocatorAndData(other.allocator_and_data_.allocator()); + } + assign(other.begin(), other.end()); + return *this; + } + + QuicheCircularDeque& operator=(QuicheCircularDeque&& other) { + if (this == &other) { + return *this; + } + if (AllocatorTraits::propagate_on_container_move_assignment::value) { + // Take over the storage of "other", along with its allocator. + this->~QuicheCircularDeque(); + new (this) QuicheCircularDeque(std::move(other)); + } else { + MoveRetainAllocator(std::move(other)); + } + return *this; + } + + ~QuicheCircularDeque() { DestroyAndDeallocateAll(); } + + void assign(size_type count, const T& value) { + ClearRetainCapacity(); + reserve(count); + for (size_t i = 0; i < count; ++i) { + emplace_back(value); + } + } + + template < + class InputIt, + typename = std::enable_if_t::iterator_category>::value>> + void assign(InputIt first, InputIt last) { + AssignRange(first, last); + } + + void assign(std::initializer_list ilist) { + assign(ilist.begin(), ilist.end()); + } + + reference at(size_type pos) { + QUICHE_DCHECK(pos < size()) << "pos:" << pos << ", size():" << size(); + size_type index = begin_ + pos; + if (index < data_capacity()) { + return *index_to_address(index); + } + return *index_to_address(index - data_capacity()); + } + + const_reference at(size_type pos) const { + return const_cast(this)->at(pos); + } + + reference operator[](size_type pos) { return at(pos); } + + const_reference operator[](size_type pos) const { return at(pos); } + + reference front() { + QUICHE_DCHECK(!empty()); + return *index_to_address(begin_); + } + + const_reference front() const { + return const_cast(this)->front(); + } + + reference back() { + QUICHE_DCHECK(!empty()); + return *(index_to_address(end_ == 0 ? data_capacity() - 1 : end_ - 1)); + } + + const_reference back() const { + return const_cast(this)->back(); + } + + iterator begin() { return iterator(this, begin_); } + const_iterator begin() const { return const_iterator(this, begin_); } + const_iterator cbegin() const { return const_iterator(this, begin_); } + + iterator end() { return iterator(this, end_); } + const_iterator end() const { return const_iterator(this, end_); } + const_iterator cend() const { return const_iterator(this, end_); } + + reverse_iterator rbegin() { return reverse_iterator(end()); } + const_reverse_iterator rbegin() const { + return const_reverse_iterator(end()); + } + const_reverse_iterator crbegin() const { return rbegin(); } + + reverse_iterator rend() { return reverse_iterator(begin()); } + const_reverse_iterator rend() const { + return const_reverse_iterator(begin()); + } + const_reverse_iterator crend() const { return rend(); } + + size_type capacity() const { + return data_capacity() == 0 ? 0 : data_capacity() - 1; + } + + void reserve(size_type new_cap) { + if (new_cap > capacity()) { + Relocate(new_cap); + } + } + + // Remove all elements. Leave capacity unchanged. + void clear() { ClearRetainCapacity(); } + + bool empty() const { return begin_ == end_; } + + size_type size() const { + if (begin_ <= end_) { + return end_ - begin_; + } + return data_capacity() + end_ - begin_; + } + + void resize(size_type count) { ResizeInternal(count); } + + void resize(size_type count, const value_type& value) { + ResizeInternal(count, value); + } + + void push_front(const T& value) { emplace_front(value); } + void push_front(T&& value) { emplace_front(std::move(value)); } + + template + reference emplace_front(Args&&... args) { + MaybeExpandCapacity(1); + begin_ = index_prev(begin_); + new (index_to_address(begin_)) T(std::forward(args)...); + return front(); + } + + void push_back(const T& value) { emplace_back(value); } + void push_back(T&& value) { emplace_back(std::move(value)); } + + template + reference emplace_back(Args&&... args) { + MaybeExpandCapacity(1); + new (index_to_address(end_)) T(std::forward(args)...); + end_ = index_next(end_); + return back(); + } + + void pop_front() { + QUICHE_DCHECK(!empty()); + DestroyByIndex(begin_); + begin_ = index_next(begin_); + MaybeShrinkCapacity(); + } + + size_type pop_front_n(size_type count) { + size_type num_elements_to_pop = std::min(count, size()); + size_type new_begin = index_increment_by(begin_, num_elements_to_pop); + DestroyRange(begin_, new_begin); + begin_ = new_begin; + MaybeShrinkCapacity(); + return num_elements_to_pop; + } + + void pop_back() { + QUICHE_DCHECK(!empty()); + end_ = index_prev(end_); + DestroyByIndex(end_); + MaybeShrinkCapacity(); + } + + size_type pop_back_n(size_type count) { + size_type num_elements_to_pop = std::min(count, size()); + size_type new_end = index_increment_by(end_, -num_elements_to_pop); + DestroyRange(new_end, end_); + end_ = new_end; + MaybeShrinkCapacity(); + return num_elements_to_pop; + } + + void swap(QuicheCircularDeque& other) { + using std::swap; + swap(begin_, other.begin_); + swap(end_, other.end_); + + if (AllocatorTraits::propagate_on_container_swap::value) { + swap(allocator_and_data_, other.allocator_and_data_); + } else { + // When propagate_on_container_swap is false, it is undefined behavior, by + // c++ standard, to swap between two AllocatorAwareContainer(s) with + // unequal allocators. + QUICHE_DCHECK(get_allocator() == other.get_allocator()) + << "Undefined swap behavior"; + swap(allocator_and_data_.data, other.allocator_and_data_.data); + swap(allocator_and_data_.data_capacity, + other.allocator_and_data_.data_capacity); + } + } + + friend void swap(QuicheCircularDeque& lhs, QuicheCircularDeque& rhs) { + lhs.swap(rhs); + } + + allocator_type get_allocator() const { + return allocator_and_data_.allocator(); + } + + friend bool operator==(const QuicheCircularDeque& lhs, + const QuicheCircularDeque& rhs) { + return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end()); + } + + friend bool operator!=(const QuicheCircularDeque& lhs, + const QuicheCircularDeque& rhs) { + return !(lhs == rhs); + } + + friend QUICHE_NO_EXPORT std::ostream& operator<<( + std::ostream& os, + const QuicheCircularDeque& dq) { + os << "{"; + for (size_type pos = 0; pos != dq.size(); ++pos) { + if (pos != 0) { + os << ","; + } + os << " " << dq[pos]; + } + os << " }"; + return os; + } + + private: + void MoveRetainAllocator(QuicheCircularDeque&& other) { + if (get_allocator() == other.get_allocator()) { + // Take over the storage of "other", with which we share an allocator. + DestroyAndDeallocateAll(); + + begin_ = other.begin_; + end_ = other.end_; + allocator_and_data_.data = other.allocator_and_data_.data; + allocator_and_data_.data_capacity = + other.allocator_and_data_.data_capacity; + + other.begin_ = other.end_ = 0; + other.allocator_and_data_.data = nullptr; + other.allocator_and_data_.data_capacity = 0; + } else { + // We cannot take over of the storage from "other", since it has a + // different allocator; we're stuck move-assigning elements individually. + ClearRetainCapacity(); + for (auto& elem : other) { + push_back(std::move(elem)); + } + other.clear(); + } + } + + template < + typename InputIt, + typename = std::enable_if_t::iterator_category>::value>> + void AssignRange(InputIt first, InputIt last) { + ClearRetainCapacity(); + if (std::is_base_of< + std::random_access_iterator_tag, + typename std::iterator_traits::iterator_category>::value) { + reserve(std::distance(first, last)); + } + for (; first != last; ++first) { + emplace_back(*first); + } + } + + // WARNING: begin_, end_ and allocator_and_data_ are not modified. + void DestroyAndDeallocateAll() { + DestroyRange(begin_, end_); + + if (data_capacity() > 0) { + QUICHE_DCHECK_NE(nullptr, allocator_and_data_.data); + AllocatorTraits::deallocate(allocator_and_data_.allocator(), + allocator_and_data_.data, data_capacity()); + } + } + + void ClearRetainCapacity() { + DestroyRange(begin_, end_); + begin_ = end_ = 0; + } + + void MaybeShrinkCapacity() { + // TODO(wub): Implement a storage policy that actually shrinks. + } + + void MaybeExpandCapacity(size_t num_additional_elements) { + size_t new_size = size() + num_additional_elements; + if (capacity() >= new_size) { + return; + } + + // The minimum amount of additional capacity to grow. + size_t min_additional_capacity = + std::max(MinCapacityIncrement, capacity() / 4); + size_t new_capacity = + std::max(new_size, capacity() + min_additional_capacity); + + Relocate(new_capacity); + } + + void Relocate(size_t new_capacity) { + const size_t num_elements = size(); + QUICHE_DCHECK_GT(new_capacity, num_elements) + << "new_capacity:" << new_capacity << ", num_elements:" << num_elements; + + size_t new_data_capacity = new_capacity + 1; + pointer new_data = AllocatorTraits::allocate( + allocator_and_data_.allocator(), new_data_capacity); + + if (begin_ < end_) { + // Not wrapped. + RelocateUnwrappedRange(begin_, end_, new_data); + } else if (begin_ > end_) { + // Wrapped. + const size_t num_elements_before_wrap = data_capacity() - begin_; + RelocateUnwrappedRange(begin_, data_capacity(), new_data); + RelocateUnwrappedRange(0, end_, new_data + num_elements_before_wrap); + } + + if (data_capacity()) { + AllocatorTraits::deallocate(allocator_and_data_.allocator(), + allocator_and_data_.data, data_capacity()); + } + + allocator_and_data_.data = new_data; + allocator_and_data_.data_capacity = new_data_capacity; + begin_ = 0; + end_ = num_elements; + } + + template + typename std::enable_if::value, void>::type + RelocateUnwrappedRange(size_type begin, size_type end, pointer dest) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + pointer src = index_to_address(begin); + QUICHE_DCHECK_NE(src, nullptr); + memcpy(dest, src, sizeof(T) * (end - begin)); + DestroyRange(begin, end); + } + + template + typename std::enable_if::value && + std::is_move_constructible::value, + void>::type + RelocateUnwrappedRange(size_type begin, size_type end, pointer dest) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + pointer src = index_to_address(begin); + pointer src_end = index_to_address(end); + while (src != src_end) { + new (dest) T(std::move(*src)); + DestroyByAddress(src); + ++dest; + ++src; + } + } + + template + typename std::enable_if::value && + !std::is_move_constructible::value, + void>::type + RelocateUnwrappedRange(size_type begin, size_type end, pointer dest) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + pointer src = index_to_address(begin); + pointer src_end = index_to_address(end); + while (src != src_end) { + new (dest) T(*src); + DestroyByAddress(src); + ++dest; + ++src; + } + } + + template + void ResizeInternal(size_type count, U&&... u) { + if (count > size()) { + // Expanding. + MaybeExpandCapacity(count - size()); + while (size() < count) { + emplace_back(std::forward(u)...); + } + } else { + // Most likely shrinking. No-op if count == size(). + size_type new_end = (begin_ + count) % data_capacity(); + DestroyRange(new_end, end_); + end_ = new_end; + + MaybeShrinkCapacity(); + } + } + + void DestroyRange(size_type begin, size_type end) const { + if (std::is_trivially_destructible::value) { + return; + } + if (end >= begin) { + DestroyUnwrappedRange(begin, end); + } else { + DestroyUnwrappedRange(begin, data_capacity()); + DestroyUnwrappedRange(0, end); + } + } + + // Should only be called from DestroyRange. + void DestroyUnwrappedRange(size_type begin, size_type end) const { + QUICHE_DCHECK_LE(begin, end) << "begin:" << begin << ", end:" << end; + for (; begin != end; ++begin) { + DestroyByIndex(begin); + } + } + + void DestroyByIndex(size_type index) const { + DestroyByAddress(index_to_address(index)); + } + + void DestroyByAddress(pointer address) const { + if (std::is_trivially_destructible::value) { + return; + } + address->~T(); + } + + size_type data_capacity() const { return allocator_and_data_.data_capacity; } + + pointer index_to_address(size_type index) const { + return allocator_and_data_.data + index; + } + + size_type index_prev(size_type index) const { + return index == 0 ? data_capacity() - 1 : index - 1; + } + + size_type index_next(size_type index) const { + return index == data_capacity() - 1 ? 0 : index + 1; + } + + size_type index_increment_by(size_type index, difference_type delta) const { + if (delta == 0) { + return index; + } + + QUICHE_DCHECK_LT(static_cast(std::abs(delta)), data_capacity()); + return (index + data_capacity() + delta) % data_capacity(); + } + + // Empty base-class optimization: bundle storage for our allocator together + // with the fields we had to store anyway, via inheriting from the allocator, + // so this allocator instance doesn't consume any storage when its type has no + // data members. + struct AllocatorAndData : private allocator_type { + explicit AllocatorAndData(const allocator_type& alloc) + : allocator_type(alloc) {} + + const allocator_type& allocator() const { return *this; } + allocator_type& allocator() { return *this; } + + pointer data = nullptr; + size_type data_capacity = 0; + }; + + size_type begin_ = 0; + size_type end_ = 0; + AllocatorAndData allocator_and_data_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_CIRCULAR_DEQUE_H_ diff --git a/gquiche/common/quiche_circular_deque_test.cc b/gquiche/common/quiche_circular_deque_test.cc new file mode 100644 index 00000000..6870c65b --- /dev/null +++ b/gquiche/common/quiche_circular_deque_test.cc @@ -0,0 +1,801 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/quiche_circular_deque.h" + +#include +#include +#include +#include +#include + +#include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/common/platform/api/quiche_test.h" + +using testing::ElementsAre; + +namespace quiche { +namespace test { +namespace { + +template class BaseAllocator = std::allocator> +class CountingAllocator : public BaseAllocator { + using BaseType = BaseAllocator; + + public: + using propagate_on_container_copy_assignment = std::true_type; + using propagate_on_container_move_assignment = std::true_type; + using propagate_on_container_swap = std::true_type; + + T* allocate(std::size_t n) { + ++shared_counts_->allocate_count; + return BaseType::allocate(n); + } + + void deallocate(T* ptr, std::size_t n) { + ++shared_counts_->deallocate_count; + return BaseType::deallocate(ptr, n); + } + + size_t allocate_count() const { return shared_counts_->allocate_count; } + + size_t deallocate_count() const { return shared_counts_->deallocate_count; } + + friend bool operator==(const CountingAllocator& lhs, + const CountingAllocator& rhs) { + return lhs.shared_counts_ == rhs.shared_counts_; + } + + friend bool operator!=(const CountingAllocator& lhs, + const CountingAllocator& rhs) { + return !(lhs == rhs); + } + + private: + struct Counts { + size_t allocate_count = 0; + size_t deallocate_count = 0; + }; + + std::shared_ptr shared_counts_ = std::make_shared(); +}; + +template class BaseAllocator = std::allocator> +struct ConfigurableAllocator : public BaseAllocator { + using propagate_on_container_copy_assignment = propagate_on_copy_assignment; + using propagate_on_container_move_assignment = propagate_on_move_assignment; + using propagate_on_container_swap = propagate_on_swap; + + friend bool operator==(const ConfigurableAllocator& /*lhs*/, + const ConfigurableAllocator& /*rhs*/) { + return equality_result; + } + + friend bool operator!=(const ConfigurableAllocator& lhs, + const ConfigurableAllocator& rhs) { + return !(lhs == rhs); + } +}; + +// [1, 2, 3, 4] ==> [4, 1, 2, 3] +template +void ShiftRight(Deque* dq, bool emplace) { + auto back = *(&dq->back()); + dq->pop_back(); + if (emplace) { + dq->emplace_front(back); + } else { + dq->push_front(back); + } +} + +// [1, 2, 3, 4] ==> [2, 3, 4, 1] +template +void ShiftLeft(Deque* dq, bool emplace) { + auto front = *(&dq->front()); + dq->pop_front(); + if (emplace) { + dq->emplace_back(front); + } else { + dq->push_back(front); + } +} + +class QuicheCircularDequeTest : public QuicheTest {}; + +TEST_F(QuicheCircularDequeTest, Empty) { + QuicheCircularDeque dq; + EXPECT_TRUE(dq.empty()); + EXPECT_EQ(0u, dq.size()); + dq.clear(); + dq.push_back(10); + EXPECT_FALSE(dq.empty()); + EXPECT_EQ(1u, dq.size()); + EXPECT_EQ(10, dq.front()); + EXPECT_EQ(10, dq.back()); + dq.pop_front(); + EXPECT_TRUE(dq.empty()); + EXPECT_EQ(0u, dq.size()); + + EXPECT_QUICHE_DEBUG_DEATH(dq.front(), ""); + EXPECT_QUICHE_DEBUG_DEATH(dq.back(), ""); + EXPECT_QUICHE_DEBUG_DEATH(dq.at(0), ""); + EXPECT_QUICHE_DEBUG_DEATH(dq[0], ""); +} + +TEST_F(QuicheCircularDequeTest, Constructor) { + QuicheCircularDeque dq; + EXPECT_TRUE(dq.empty()); + + std::allocator alloc; + QuicheCircularDeque dq1(alloc); + EXPECT_TRUE(dq1.empty()); + + QuicheCircularDeque dq2(8, 100, alloc); + EXPECT_THAT(dq2, ElementsAre(100, 100, 100, 100, 100, 100, 100, 100)); + + QuicheCircularDeque dq3(5, alloc); + EXPECT_THAT(dq3, ElementsAre(0, 0, 0, 0, 0)); + + QuicheCircularDeque dq4_rand_iter(dq3.begin(), dq3.end(), alloc); + EXPECT_THAT(dq4_rand_iter, ElementsAre(0, 0, 0, 0, 0)); + EXPECT_EQ(dq4_rand_iter, dq3); + + std::list dq4_src = {4, 4, 4, 4}; + QuicheCircularDeque dq4_bidi_iter(dq4_src.begin(), dq4_src.end()); + EXPECT_THAT(dq4_bidi_iter, ElementsAre(4, 4, 4, 4)); + + QuicheCircularDeque dq5(dq4_bidi_iter); + EXPECT_THAT(dq5, ElementsAre(4, 4, 4, 4)); + EXPECT_EQ(dq5, dq4_bidi_iter); + + QuicheCircularDeque dq6(dq5, alloc); + EXPECT_THAT(dq6, ElementsAre(4, 4, 4, 4)); + EXPECT_EQ(dq6, dq5); + + QuicheCircularDeque dq7(std::move(*&dq6)); + EXPECT_THAT(dq7, ElementsAre(4, 4, 4, 4)); + EXPECT_TRUE(dq6.empty()); + + QuicheCircularDeque dq8_equal_allocator(std::move(*&dq7), alloc); + EXPECT_THAT(dq8_equal_allocator, ElementsAre(4, 4, 4, 4)); + EXPECT_TRUE(dq7.empty()); + + QuicheCircularDeque> dq8_temp = {5, 6, 7, 8, + 9}; + QuicheCircularDeque> dq8_unequal_allocator( + std::move(*&dq8_temp), CountingAllocator()); + EXPECT_THAT(dq8_unequal_allocator, ElementsAre(5, 6, 7, 8, 9)); + EXPECT_TRUE(dq8_temp.empty()); + + QuicheCircularDeque dq9({3, 4, 5, 6, 7}, alloc); + EXPECT_THAT(dq9, ElementsAre(3, 4, 5, 6, 7)); +} + +TEST_F(QuicheCircularDequeTest, Assign) { + // assign() + QuicheCircularDeque> dq; + dq.assign(7, 1); + EXPECT_THAT(dq, ElementsAre(1, 1, 1, 1, 1, 1, 1)); + EXPECT_EQ(1u, dq.get_allocator().allocate_count()); + + QuicheCircularDeque> dq2; + dq2.assign(dq.begin(), dq.end()); + EXPECT_THAT(dq2, ElementsAre(1, 1, 1, 1, 1, 1, 1)); + EXPECT_EQ(1u, dq2.get_allocator().allocate_count()); + EXPECT_TRUE(std::equal(dq.begin(), dq.end(), dq2.begin(), dq2.end())); + + dq2.assign({2, 2, 2, 2, 2, 2}); + EXPECT_THAT(dq2, ElementsAre(2, 2, 2, 2, 2, 2)); + + // Assign from a non random access iterator. + std::list dq3_src = {3, 3, 3, 3, 3}; + QuicheCircularDeque> dq3; + dq3.assign(dq3_src.begin(), dq3_src.end()); + EXPECT_THAT(dq3, ElementsAre(3, 3, 3, 3, 3)); + EXPECT_LT(1u, dq3.get_allocator().allocate_count()); + + // Copy assignment + dq3 = *&dq3; + EXPECT_THAT(dq3, ElementsAre(3, 3, 3, 3, 3)); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq4, dq5; + dq4.assign(dq3.begin(), dq3.end()); + dq5 = dq4; + EXPECT_THAT(dq5, ElementsAre(3, 3, 3, 3, 3)); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq6, dq7; + dq6.assign(dq3.begin(), dq3.end()); + dq7 = dq6; + EXPECT_THAT(dq7, ElementsAre(3, 3, 3, 3, 3)); + + // Move assignment + dq3 = std::move(*&dq3); + EXPECT_THAT(dq3, ElementsAre(3, 3, 3, 3, 3)); + + ASSERT_TRUE(decltype( + dq3.get_allocator())::propagate_on_container_move_assignment::value); + decltype(dq3) dq8; + dq8 = std::move(*&dq3); + EXPECT_THAT(dq8, ElementsAre(3, 3, 3, 3, 3)); + EXPECT_TRUE(dq3.empty()); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq9, dq10; + dq9.assign(dq8.begin(), dq8.end()); + dq10.assign(dq2.begin(), dq2.end()); + dq9 = std::move(*&dq10); + EXPECT_THAT(dq9, ElementsAre(2, 2, 2, 2, 2, 2)); + EXPECT_TRUE(dq10.empty()); + + QuicheCircularDeque< + int, 3, + ConfigurableAllocator> + dq11, dq12; + dq11.assign(dq8.begin(), dq8.end()); + dq12.assign(dq2.begin(), dq2.end()); + dq11 = std::move(*&dq12); + EXPECT_THAT(dq11, ElementsAre(2, 2, 2, 2, 2, 2)); + EXPECT_TRUE(dq12.empty()); +} + +TEST_F(QuicheCircularDequeTest, Access) { + // at() + // operator[] + // front() + // back() + + QuicheCircularDeque> dq; + dq.push_back(10); + EXPECT_EQ(dq.front(), 10); + EXPECT_EQ(dq.back(), 10); + EXPECT_EQ(dq.at(0), 10); + EXPECT_EQ(dq[0], 10); + dq.front() = 12; + EXPECT_EQ(dq.front(), 12); + EXPECT_EQ(dq.back(), 12); + EXPECT_EQ(dq.at(0), 12); + EXPECT_EQ(dq[0], 12); + + const auto& dqref = dq; + EXPECT_EQ(dqref.front(), 12); + EXPECT_EQ(dqref.back(), 12); + EXPECT_EQ(dqref.at(0), 12); + EXPECT_EQ(dqref[0], 12); + + dq.pop_front(); + EXPECT_TRUE(dqref.empty()); + + // Push to capacity. + dq.push_back(15); + dq.push_front(5); + dq.push_back(25); + EXPECT_EQ(dq.size(), dq.capacity()); + EXPECT_THAT(dq, ElementsAre(5, 15, 25)); + EXPECT_LT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 5); + EXPECT_EQ(dq.back(), 25); + EXPECT_EQ(dq.at(0), 5); + EXPECT_EQ(dq.at(1), 15); + EXPECT_EQ(dq.at(2), 25); + EXPECT_EQ(dq[0], 5); + EXPECT_EQ(dq[1], 15); + EXPECT_EQ(dq[2], 25); + + // Shift right such that begin=1 and end=0. Data is still not wrapped. + dq.pop_front(); + dq.push_back(35); + EXPECT_THAT(dq, ElementsAre(15, 25, 35)); + EXPECT_LT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 15); + EXPECT_EQ(dq.back(), 35); + EXPECT_EQ(dq.at(0), 15); + EXPECT_EQ(dq.at(1), 25); + EXPECT_EQ(dq.at(2), 35); + EXPECT_EQ(dq[0], 15); + EXPECT_EQ(dq[1], 25); + EXPECT_EQ(dq[2], 35); + + // Shift right such that data is wrapped. + dq.pop_front(); + dq.push_back(45); + EXPECT_THAT(dq, ElementsAre(25, 35, 45)); + EXPECT_GT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 25); + EXPECT_EQ(dq.back(), 45); + EXPECT_EQ(dq.at(0), 25); + EXPECT_EQ(dq.at(1), 35); + EXPECT_EQ(dq.at(2), 45); + EXPECT_EQ(dq[0], 25); + EXPECT_EQ(dq[1], 35); + EXPECT_EQ(dq[2], 45); + + // Shift right again, data is still wrapped. + dq.pop_front(); + dq.push_back(55); + EXPECT_THAT(dq, ElementsAre(35, 45, 55)); + EXPECT_GT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 35); + EXPECT_EQ(dq.back(), 55); + EXPECT_EQ(dq.at(0), 35); + EXPECT_EQ(dq.at(1), 45); + EXPECT_EQ(dq.at(2), 55); + EXPECT_EQ(dq[0], 35); + EXPECT_EQ(dq[1], 45); + EXPECT_EQ(dq[2], 55); + + // Shift right one last time. begin returns to 0. Data is no longer wrapped. + dq.pop_front(); + dq.push_back(65); + EXPECT_THAT(dq, ElementsAre(45, 55, 65)); + EXPECT_LT(&dq.front(), &dq.back()); + EXPECT_EQ(dq.front(), 45); + EXPECT_EQ(dq.back(), 65); + EXPECT_EQ(dq.at(0), 45); + EXPECT_EQ(dq.at(1), 55); + EXPECT_EQ(dq.at(2), 65); + EXPECT_EQ(dq[0], 45); + EXPECT_EQ(dq[1], 55); + EXPECT_EQ(dq[2], 65); + + EXPECT_EQ(1u, dq.get_allocator().allocate_count()); +} + +TEST_F(QuicheCircularDequeTest, Iterate) { + QuicheCircularDeque dq; + EXPECT_EQ(dq.begin(), dq.end()); + EXPECT_EQ(dq.cbegin(), dq.cend()); + EXPECT_EQ(dq.rbegin(), dq.rend()); + EXPECT_EQ(dq.crbegin(), dq.crend()); + + dq.emplace_back(2); + QuicheCircularDeque::const_iterator citer = dq.begin(); + EXPECT_NE(citer, dq.end()); + EXPECT_EQ(*citer, 2); + ++citer; + EXPECT_EQ(citer, dq.end()); + + EXPECT_EQ(*dq.begin(), 2); + EXPECT_EQ(*dq.cbegin(), 2); + EXPECT_EQ(*dq.rbegin(), 2); + EXPECT_EQ(*dq.crbegin(), 2); + + dq.emplace_front(1); + QuicheCircularDeque::const_reverse_iterator criter = dq.rbegin(); + EXPECT_NE(criter, dq.rend()); + EXPECT_EQ(*criter, 2); + ++criter; + EXPECT_NE(criter, dq.rend()); + EXPECT_EQ(*criter, 1); + ++criter; + EXPECT_EQ(criter, dq.rend()); + + EXPECT_EQ(*dq.begin(), 1); + EXPECT_EQ(*dq.cbegin(), 1); + EXPECT_EQ(*dq.rbegin(), 2); + EXPECT_EQ(*dq.crbegin(), 2); + + dq.push_back(3); + + // Forward iterate. + int expected_value = 1; + for (QuicheCircularDeque::iterator it = dq.begin(); it != dq.end(); + ++it) { + EXPECT_EQ(expected_value++, *it); + } + + expected_value = 1; + for (QuicheCircularDeque::const_iterator it = dq.cbegin(); + it != dq.cend(); ++it) { + EXPECT_EQ(expected_value++, *it); + } + + // Reverse iterate. + expected_value = 3; + for (QuicheCircularDeque::reverse_iterator it = dq.rbegin(); + it != dq.rend(); ++it) { + EXPECT_EQ(expected_value--, *it); + } + + expected_value = 3; + for (QuicheCircularDeque::const_reverse_iterator it = dq.crbegin(); + it != dq.crend(); ++it) { + EXPECT_EQ(expected_value--, *it); + } +} + +TEST_F(QuicheCircularDequeTest, Iterator) { + // Default constructed iterators of the same type compare equal. + EXPECT_EQ(QuicheCircularDeque::iterator(), + QuicheCircularDeque::iterator()); + EXPECT_EQ(QuicheCircularDeque::const_iterator(), + QuicheCircularDeque::const_iterator()); + EXPECT_EQ(QuicheCircularDeque::reverse_iterator(), + QuicheCircularDeque::reverse_iterator()); + EXPECT_EQ(QuicheCircularDeque::const_reverse_iterator(), + QuicheCircularDeque::const_reverse_iterator()); + + QuicheCircularDeque, 3> dqdq = { + {1, 2}, {10, 20, 30}, {100, 200, 300, 400}}; + + // iter points to {1, 2} + decltype(dqdq)::iterator iter = dqdq.begin(); + EXPECT_EQ(iter->size(), 2u); + EXPECT_THAT(*iter, ElementsAre(1, 2)); + + // citer points to {10, 20, 30} + decltype(dqdq)::const_iterator citer = dqdq.cbegin() + 1; + EXPECT_NE(*iter, *citer); + EXPECT_EQ(citer->size(), 3u); + int x = 10; + for (auto it = citer->begin(); it != citer->end(); ++it) { + EXPECT_EQ(*it, x); + x += 10; + } + + EXPECT_LT(iter, citer); + EXPECT_LE(iter, iter); + EXPECT_GT(citer, iter); + EXPECT_GE(citer, citer); + + // iter points to {100, 200, 300, 400} + iter += 2; + EXPECT_NE(*iter, *citer); + EXPECT_EQ(iter->size(), 4u); + for (int i = 1; i <= 4; ++i) { + EXPECT_EQ(iter->begin()[i - 1], i * 100); + } + + EXPECT_LT(citer, iter); + EXPECT_LE(iter, iter); + EXPECT_GT(iter, citer); + EXPECT_GE(citer, citer); + + // iter points to {10, 20, 30}. (same as citer) + iter -= 1; + EXPECT_EQ(*iter, *citer); + EXPECT_EQ(iter->size(), 3u); + x = 10; + for (auto it = iter->begin(); it != iter->end();) { + EXPECT_EQ(*(it++), x); + x += 10; + } + x = 30; + for (auto it = iter->begin() + 2; it != iter->begin();) { + EXPECT_EQ(*(it--), x); + x -= 10; + } +} + +TEST_F(QuicheCircularDequeTest, Resize) { + QuicheCircularDeque> dq; + dq.resize(8); + EXPECT_THAT(dq, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0)); + EXPECT_EQ(1u, dq.get_allocator().allocate_count()); + + dq.resize(10, 5); + EXPECT_THAT(dq, ElementsAre(0, 0, 0, 0, 0, 0, 0, 0, 5, 5)); + + QuicheCircularDeque> dq2 = dq; + + for (size_t new_size = dq.size(); new_size != 0; --new_size) { + dq.resize(new_size); + EXPECT_TRUE( + std::equal(dq.begin(), dq.end(), dq2.begin(), dq2.begin() + new_size)); + } + + dq.resize(0); + EXPECT_TRUE(dq.empty()); + + // Resize when data is wrapped. + ASSERT_EQ(dq2.size(), dq2.capacity()); + while (dq2.size() < dq2.capacity()) { + dq2.push_back(5); + } + + // Shift left once such that data is wrapped. + ASSERT_LT(&dq2.front(), &dq2.back()); + dq2.pop_back(); + dq2.push_front(-5); + ASSERT_GT(&dq2.front(), &dq2.back()); + + EXPECT_EQ(-5, dq2.front()); + EXPECT_EQ(5, dq2.back()); + dq2.resize(dq2.size() + 1, 10); + + // Data should be unwrapped after the resize. + ASSERT_LT(&dq2.front(), &dq2.back()); + EXPECT_EQ(-5, dq2.front()); + EXPECT_EQ(10, dq2.back()); + EXPECT_EQ(5, *(dq2.rbegin() + 1)); +} + +namespace { +class Foo { + public: + Foo() : Foo(0xF00) {} + + explicit Foo(int i) : i_(new int(i)) {} + + ~Foo() { + if (i_ != nullptr) { + delete i_; + // Do not set i_ to nullptr such that if the container calls destructor + // multiple times, asan can detect it. + } + } + + Foo(const Foo& other) : i_(new int(*other.i_)) {} + + Foo(Foo&& other) = delete; + + void Set(int i) { *i_ = i; } + + int i() const { return *i_; } + + friend bool operator==(const Foo& lhs, const Foo& rhs) { + return lhs.i() == rhs.i(); + } + + friend std::ostream& operator<<(std::ostream& os, const Foo& foo) { + return os << "Foo(" << foo.i() << ")"; + } + + private: + // By pointing i_ to a dynamically allocated integer, a memory leak will be + // reported if the container forget to properly destruct this object. + int* i_ = nullptr; +}; +} // namespace + +TEST_F(QuicheCircularDequeTest, RelocateNonTriviallyCopyable) { + // When relocating non-trivially-copyable objects: + // - Move constructor is preferred, if available. + // - Copy constructor is used otherwise. + + { + // Move construct in Relocate. + using MoveConstructible = std::unique_ptr; + ASSERT_FALSE(std::is_trivially_copyable::value); + ASSERT_TRUE(std::is_move_constructible::value); + QuicheCircularDeque> + dq1; + dq1.resize(3); + EXPECT_EQ(dq1.size(), dq1.capacity()); + EXPECT_EQ(1u, dq1.get_allocator().allocate_count()); + + dq1.emplace_back(new Foo(0xF1)); // Cause existing elements to relocate. + EXPECT_EQ(4u, dq1.size()); + EXPECT_EQ(2u, dq1.get_allocator().allocate_count()); + EXPECT_EQ(dq1[0], nullptr); + EXPECT_EQ(dq1[1], nullptr); + EXPECT_EQ(dq1[2], nullptr); + EXPECT_EQ(dq1[3]->i(), 0xF1); + } + + { + // Copy construct in Relocate. + using NonMoveConstructible = Foo; + ASSERT_FALSE(std::is_trivially_copyable::value); + ASSERT_FALSE(std::is_move_constructible::value); + QuicheCircularDeque> + dq2; + dq2.resize(3); + EXPECT_EQ(dq2.size(), dq2.capacity()); + EXPECT_EQ(1u, dq2.get_allocator().allocate_count()); + + dq2.emplace_back(0xF1); // Cause existing elements to relocate. + EXPECT_EQ(4u, dq2.size()); + EXPECT_EQ(2u, dq2.get_allocator().allocate_count()); + EXPECT_EQ(dq2[0].i(), 0xF00); + EXPECT_EQ(dq2[1].i(), 0xF00); + EXPECT_EQ(dq2[2].i(), 0xF00); + EXPECT_EQ(dq2[3].i(), 0xF1); + } +} + +TEST_F(QuicheCircularDequeTest, PushPop) { + // (push|pop|emplace)_(back|front) + + { + QuicheCircularDeque> dq(4); + for (size_t i = 0; i < dq.size(); ++i) { + dq[i].Set(i + 1); + } + QUICHE_LOG(INFO) << "dq initialized to " << dq; + EXPECT_THAT(dq, ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4))); + + ShiftLeft(&dq, false); + QUICHE_LOG(INFO) << "shift left once : " << dq; + EXPECT_THAT(dq, ElementsAre(Foo(2), Foo(3), Foo(4), Foo(1))); + + ShiftLeft(&dq, true); + QUICHE_LOG(INFO) << "shift left twice: " << dq; + EXPECT_THAT(dq, ElementsAre(Foo(3), Foo(4), Foo(1), Foo(2))); + ASSERT_GT(&dq.front(), &dq.back()); + // dq destructs with wrapped data. + } + + { + QuicheCircularDeque> dq1(4); + for (size_t i = 0; i < dq1.size(); ++i) { + dq1[i].Set(i + 1); + } + QUICHE_LOG(INFO) << "dq1 initialized to " << dq1; + EXPECT_THAT(dq1, ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4))); + + ShiftRight(&dq1, false); + QUICHE_LOG(INFO) << "shift right once : " << dq1; + EXPECT_THAT(dq1, ElementsAre(Foo(4), Foo(1), Foo(2), Foo(3))); + + ShiftRight(&dq1, true); + QUICHE_LOG(INFO) << "shift right twice: " << dq1; + EXPECT_THAT(dq1, ElementsAre(Foo(3), Foo(4), Foo(1), Foo(2))); + ASSERT_GT(&dq1.front(), &dq1.back()); + // dq1 destructs with wrapped data. + } + + { // Pop n elements from front. + QuicheCircularDeque> dq2(5); + for (size_t i = 0; i < dq2.size(); ++i) { + dq2[i].Set(i + 1); + } + EXPECT_THAT(dq2, ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4), Foo(5))); + + EXPECT_EQ(2u, dq2.pop_front_n(2)); + EXPECT_THAT(dq2, ElementsAre(Foo(3), Foo(4), Foo(5))); + + EXPECT_EQ(3u, dq2.pop_front_n(100)); + EXPECT_TRUE(dq2.empty()); + } + + { // Pop n elements from back. + QuicheCircularDeque> dq3(6); + for (size_t i = 0; i < dq3.size(); ++i) { + dq3[i].Set(i + 1); + } + EXPECT_THAT(dq3, + ElementsAre(Foo(1), Foo(2), Foo(3), Foo(4), Foo(5), Foo(6))); + + ShiftRight(&dq3, true); + ShiftRight(&dq3, true); + ShiftRight(&dq3, true); + EXPECT_THAT(dq3, + ElementsAre(Foo(4), Foo(5), Foo(6), Foo(1), Foo(2), Foo(3))); + + EXPECT_EQ(2u, dq3.pop_back_n(2)); + EXPECT_THAT(dq3, ElementsAre(Foo(4), Foo(5), Foo(6), Foo(1))); + + EXPECT_EQ(2u, dq3.pop_back_n(2)); + EXPECT_THAT(dq3, ElementsAre(Foo(4), Foo(5))); + } +} + +TEST_F(QuicheCircularDequeTest, Allocation) { + CountingAllocator alloc; + + { + QuicheCircularDeque> dq(alloc); + EXPECT_EQ(alloc, dq.get_allocator()); + EXPECT_EQ(0u, dq.size()); + EXPECT_EQ(0u, dq.capacity()); + EXPECT_EQ(0u, alloc.allocate_count()); + EXPECT_EQ(0u, alloc.deallocate_count()); + + for (int i = 1; i <= 18; ++i) { + SCOPED_TRACE(testing::Message() + << "i=" << i << ", capacity_b4_push=" << dq.capacity()); + dq.push_back(i); + EXPECT_EQ(i, static_cast(dq.size())); + + const size_t capacity = 3 + (i - 1) / 3 * 3; + EXPECT_EQ(capacity, dq.capacity()); + EXPECT_EQ(capacity / 3, alloc.allocate_count()); + EXPECT_EQ(capacity / 3 - 1, alloc.deallocate_count()); + } + + dq.push_back(19); + EXPECT_EQ(22u, dq.capacity()); // 18 + 18 / 4 + EXPECT_EQ(7u, alloc.allocate_count()); + EXPECT_EQ(6u, alloc.deallocate_count()); + } + + EXPECT_EQ(7u, alloc.deallocate_count()); +} + +} // namespace +} // namespace test +} // namespace quiche + +// Use a non-quiche namespace to make sure swap can be used via ADL. +namespace { + +template +using SwappableAllocator = quiche::test::ConfigurableAllocator< + T, + /*propagate_on_copy_assignment=*/std::true_type, + /*propagate_on_move_assignment=*/std::true_type, + /*propagate_on_swap=*/std::true_type, + /*equality_result=*/true>; + +template +using UnswappableEqualAllocator = quiche::test::ConfigurableAllocator< + T, + /*propagate_on_copy_assignment=*/std::true_type, + /*propagate_on_move_assignment=*/std::true_type, + /*propagate_on_swap=*/std::false_type, + /*equality_result=*/true>; + +template +using UnswappableUnequalAllocator = quiche::test::ConfigurableAllocator< + T, + /*propagate_on_copy_assignment=*/std::true_type, + /*propagate_on_move_assignment=*/std::true_type, + /*propagate_on_swap=*/std::false_type, + /*equality_result=*/false>; + +using quiche::test::QuicheCircularDequeTest; + +TEST_F(QuicheCircularDequeTest, Swap) { + using std::swap; + + quiche::QuicheCircularDeque> dq1, dq2; + dq1.push_back(10); + dq1.push_back(11); + dq2.push_back(20); + swap(dq1, dq2); + EXPECT_THAT(dq1, ElementsAre(20)); + EXPECT_THAT(dq2, ElementsAre(10, 11)); + + quiche::QuicheCircularDeque> dq3, + dq4; + dq3 = {1, 2, 3, 4, 5}; + dq4 = {6, 7, 8, 9, 0}; + swap(dq3, dq4); + EXPECT_THAT(dq3, ElementsAre(6, 7, 8, 9, 0)); + EXPECT_THAT(dq4, ElementsAre(1, 2, 3, 4, 5)); + + quiche::QuicheCircularDeque> dq5, + dq6; + dq6.push_front(4); + + // Using UnswappableUnequalAllocator is ok as long as swap is not called. + dq5.assign(dq6.begin(), dq6.end()); + EXPECT_THAT(dq5, ElementsAre(4)); + + // Undefined behavior to swap between two containers with unequal allocators. + EXPECT_QUICHE_DEBUG_DEATH(swap(dq5, dq6), "Undefined swap behavior"); +} +} // namespace diff --git a/gquiche/common/quiche_data_reader.cc b/gquiche/common/quiche_data_reader.cc index 8895647a..37de025c 100644 --- a/gquiche/common/quiche_data_reader.cc +++ b/gquiche/common/quiche_data_reader.cc @@ -10,7 +10,6 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "gquiche/common/platform/api/quiche_logging.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/quiche_endian.h" namespace quiche { diff --git a/gquiche/common/quiche_linked_hash_map.h b/gquiche/common/quiche_linked_hash_map.h new file mode 100644 index 00000000..fc96b77d --- /dev/null +++ b/gquiche/common/quiche_linked_hash_map.h @@ -0,0 +1,237 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This is a simplistic insertion-ordered map. It behaves similarly to an STL +// map, but only implements a small subset of the map's methods. Internally, we +// just keep a map and a list going in parallel. +// +// This class provides no thread safety guarantees, beyond what you would +// normally see with std::list. +// +// Iterators point into the list and should be stable in the face of +// mutations, except for an iterator pointing to an element that was just +// deleted. + +#ifndef QUICHE_COMMON_QUICHE_LINKED_HASH_MAP_H_ +#define QUICHE_COMMON_QUICHE_LINKED_HASH_MAP_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/node_hash_map.h" +#include "absl/hash/hash.h" +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace quiche { + +// This holds a list of pair items. This list is what gets +// traversed, and it's iterators from this list that we return from +// begin/end/find. +// +// We also keep a set for find. Since std::list is a +// doubly-linked list, the iterators should remain stable. + +// QUICHE_NO_EXPORT comments suppress erroneous presubmit failures. +template , // QUICHE_NO_EXPORT + class Eq = std::equal_to> // QUICHE_NO_EXPORT +class QuicheLinkedHashMap { // QUICHE_NO_EXPORT + private: + typedef std::list> ListType; + typedef absl::node_hash_map + MapType; + + public: + typedef typename ListType::iterator iterator; + typedef typename ListType::reverse_iterator reverse_iterator; + typedef typename ListType::const_iterator const_iterator; + typedef typename ListType::const_reverse_iterator const_reverse_iterator; + typedef typename MapType::key_type key_type; + typedef typename ListType::value_type value_type; + typedef typename ListType::size_type size_type; + + QuicheLinkedHashMap() = default; + explicit QuicheLinkedHashMap(size_type bucket_count) : map_(bucket_count) {} + + QuicheLinkedHashMap(const QuicheLinkedHashMap& other) = delete; + QuicheLinkedHashMap& operator=(const QuicheLinkedHashMap& other) = delete; + QuicheLinkedHashMap(QuicheLinkedHashMap&& other) = default; + QuicheLinkedHashMap& operator=(QuicheLinkedHashMap&& other) = default; + + // Returns an iterator to the first (insertion-ordered) element. Like a map, + // this can be dereferenced to a pair. + iterator begin() { return list_.begin(); } + const_iterator begin() const { return list_.begin(); } + + // Returns an iterator beyond the last element. + iterator end() { return list_.end(); } + const_iterator end() const { return list_.end(); } + + // Returns an iterator to the last (insertion-ordered) element. Like a map, + // this can be dereferenced to a pair. + reverse_iterator rbegin() { return list_.rbegin(); } + const_reverse_iterator rbegin() const { return list_.rbegin(); } + + // Returns an iterator beyond the first element. + reverse_iterator rend() { return list_.rend(); } + const_reverse_iterator rend() const { return list_.rend(); } + + // Front and back accessors common to many stl containers. + + // Returns the earliest-inserted element + const value_type& front() const { return list_.front(); } + + // Returns the earliest-inserted element. + value_type& front() { return list_.front(); } + + // Returns the most-recently-inserted element. + const value_type& back() const { return list_.back(); } + + // Returns the most-recently-inserted element. + value_type& back() { return list_.back(); } + + // Clears the map of all values. + void clear() { + map_.clear(); + list_.clear(); + } + + // Returns true iff the map is empty. + bool empty() const { return list_.empty(); } + + // Removes the first element from the list. + void pop_front() { erase(begin()); } + + // Erases values with the provided key. Returns the number of elements + // erased. In this implementation, this will be 0 or 1. + size_type erase(const Key& key) { + typename MapType::iterator found = map_.find(key); + if (found == map_.end()) { + return 0; + } + + list_.erase(found->second); + map_.erase(found); + + return 1; + } + + // Erases the item that 'position' points to. Returns an iterator that points + // to the item that comes immediately after the deleted item in the list, or + // end(). + // If the provided iterator is invalid or there is inconsistency between the + // map and list, a QUICHE_CHECK() error will occur. + iterator erase(iterator position) { + typename MapType::iterator found = map_.find(position->first); + QUICHE_CHECK(found->second == position) + << "Inconsistent iterator for map and list, or the iterator is " + "invalid."; + + map_.erase(found); + return list_.erase(position); + } + + // Erases all the items in the range [first, last). Returns an iterator that + // points to the item that comes immediately after the last deleted item in + // the list, or end(). + iterator erase(iterator first, iterator last) { + while (first != last && first != end()) { + first = erase(first); + } + return first; + } + + // Finds the element with the given key. Returns an iterator to the + // value found, or to end() if the value was not found. Like a map, this + // iterator points to a pair. + iterator find(const Key& key) { + typename MapType::iterator found = map_.find(key); + if (found == map_.end()) { + return end(); + } + return found->second; + } + + const_iterator find(const Key& key) const { + typename MapType::const_iterator found = map_.find(key); + if (found == map_.end()) { + return end(); + } + return found->second; + } + + bool contains(const Key& key) const { return find(key) != end(); } + + // Returns the value mapped to key, or an inserted iterator to that position + // in the map. + Value& operator[](const key_type& key) { + return (*((this->insert(std::make_pair(key, Value()))).first)).second; + } + + // Inserts an element into the map + std::pair insert(const std::pair& pair) { + return InsertInternal(pair); + } + + // Inserts an element into the map + std::pair insert(std::pair&& pair) { + return InsertInternal(std::move(pair)); + } + + // Derive size_ from map_, as list::size might be O(N). + size_type size() const { return map_.size(); } + + template + std::pair emplace(Args&&... args) { + ListType node_donor; + auto node_pos = + node_donor.emplace(node_donor.end(), std::forward(args)...); + const auto& k = node_pos->first; + auto ins = map_.insert({k, node_pos}); + if (!ins.second) { + return {ins.first->second, false}; + } + list_.splice(list_.end(), node_donor, node_pos); + return {ins.first->second, true}; + } + + void swap(QuicheLinkedHashMap& other) { + map_.swap(other.map_); + list_.swap(other.list_); + } + + private: + template + std::pair InsertInternal(U&& pair) { + auto insert_result = map_.try_emplace(pair.first); + auto map_iter = insert_result.first; + + // If the map already contains this key, return a pair with an iterator to + // it, and false indicating that we didn't insert anything. + if (!insert_result.second) { + return {map_iter->second, false}; + } + + // Otherwise, insert into the list, and set value in map. + auto list_iter = list_.insert(list_.end(), std::forward(pair)); + map_iter->second = list_iter; + + return {list_iter, true}; + } + + // The map component, used for speedy lookups + MapType map_; + + // The list component, used for maintaining insertion order + ListType list_; +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_LINKED_HASH_MAP_H_ diff --git a/gquiche/common/quiche_linked_hash_map_test.cc b/gquiche/common/quiche_linked_hash_map_test.cc new file mode 100644 index 00000000..e2643a0e --- /dev/null +++ b/gquiche/common/quiche_linked_hash_map_test.cc @@ -0,0 +1,395 @@ +// Copyright (c) 2019 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Tests QuicheLinkedHashMap. + +#include "gquiche/common/quiche_linked_hash_map.h" + +#include +#include + +#include "gquiche/common/platform/api/quiche_test.h" + +using testing::Pair; +using testing::Pointee; +using testing::UnorderedElementsAre; + +namespace quiche { +namespace test { + +// Tests that move constructor works. +TEST(LinkedHashMapTest, Move) { + // Use unique_ptr as an example of a non-copyable type. + QuicheLinkedHashMap> m; + m[2] = std::make_unique(12); + m[3] = std::make_unique(13); + QuicheLinkedHashMap> n = std::move(m); + EXPECT_THAT(n, + UnorderedElementsAre(Pair(2, Pointee(12)), Pair(3, Pointee(13)))); +} + +TEST(LinkedHashMapTest, CanEmplaceMoveOnly) { + QuicheLinkedHashMap> m; + struct Data { + int k, v; + }; + const Data data[] = {{1, 123}, {3, 345}, {2, 234}, {4, 456}}; + for (const auto& kv : data) { + m.emplace(std::piecewise_construct, std::make_tuple(kv.k), + std::make_tuple(new int{kv.v})); + } + EXPECT_TRUE(m.contains(2)); + auto found = m.find(2); + ASSERT_TRUE(found != m.end()); + EXPECT_EQ(234, *found->second); +} + +struct NoCopy { + explicit NoCopy(int x) : x(x) {} + NoCopy(const NoCopy&) = delete; + NoCopy& operator=(const NoCopy&) = delete; + NoCopy(NoCopy&&) = delete; + NoCopy& operator=(NoCopy&&) = delete; + int x; +}; + +TEST(LinkedHashMapTest, CanEmplaceNoMoveNoCopy) { + QuicheLinkedHashMap m; + struct Data { + int k, v; + }; + const Data data[] = {{1, 123}, {3, 345}, {2, 234}, {4, 456}}; + for (const auto& kv : data) { + m.emplace(std::piecewise_construct, std::make_tuple(kv.k), + std::make_tuple(kv.v)); + } + EXPECT_TRUE(m.contains(2)); + auto found = m.find(2); + ASSERT_TRUE(found != m.end()); + EXPECT_EQ(234, found->second.x); +} + +TEST(LinkedHashMapTest, ConstKeys) { + QuicheLinkedHashMap m; + m.insert(std::make_pair(1, 2)); + // Test that keys are const in iteration. + std::pair& p = *m.begin(); + EXPECT_EQ(1, p.first); +} + +// Tests that iteration from begin() to end() works +TEST(LinkedHashMapTest, Iteration) { + QuicheLinkedHashMap m; + EXPECT_TRUE(m.begin() == m.end()); + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + QuicheLinkedHashMap::iterator i = m.begin(); + ASSERT_TRUE(m.begin() == i); + ASSERT_TRUE(m.end() != i); + EXPECT_EQ(2, i->first); + EXPECT_EQ(12, i->second); + + ++i; + ASSERT_TRUE(m.end() != i); + EXPECT_EQ(1, i->first); + EXPECT_EQ(11, i->second); + + ++i; + ASSERT_TRUE(m.end() != i); + EXPECT_EQ(3, i->first); + EXPECT_EQ(13, i->second); + + ++i; // Should be the end of the line. + ASSERT_TRUE(m.end() == i); +} + +// Tests that reverse iteration from rbegin() to rend() works +TEST(LinkedHashMapTest, ReverseIteration) { + QuicheLinkedHashMap m; + EXPECT_TRUE(m.rbegin() == m.rend()); + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + QuicheLinkedHashMap::reverse_iterator i = m.rbegin(); + ASSERT_TRUE(m.rbegin() == i); + ASSERT_TRUE(m.rend() != i); + EXPECT_EQ(3, i->first); + EXPECT_EQ(13, i->second); + + ++i; + ASSERT_TRUE(m.rend() != i); + EXPECT_EQ(1, i->first); + EXPECT_EQ(11, i->second); + + ++i; + ASSERT_TRUE(m.rend() != i); + EXPECT_EQ(2, i->first); + EXPECT_EQ(12, i->second); + + ++i; // Should be the end of the line. + ASSERT_TRUE(m.rend() == i); +} + +// Tests that clear() works +TEST(LinkedHashMapTest, Clear) { + QuicheLinkedHashMap m; + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + ASSERT_EQ(3u, m.size()); + + m.clear(); + + EXPECT_EQ(0u, m.size()); + + m.clear(); // Make sure we can call it on an empty map. + + EXPECT_EQ(0u, m.size()); +} + +// Tests that size() works. +TEST(LinkedHashMapTest, Size) { + QuicheLinkedHashMap m; + EXPECT_EQ(0u, m.size()); + m.insert(std::make_pair(2, 12)); + EXPECT_EQ(1u, m.size()); + m.insert(std::make_pair(1, 11)); + EXPECT_EQ(2u, m.size()); + m.insert(std::make_pair(3, 13)); + EXPECT_EQ(3u, m.size()); + m.clear(); + EXPECT_EQ(0u, m.size()); +} + +// Tests empty() +TEST(LinkedHashMapTest, Empty) { + QuicheLinkedHashMap m; + ASSERT_TRUE(m.empty()); + m.insert(std::make_pair(2, 12)); + ASSERT_FALSE(m.empty()); + m.clear(); + ASSERT_TRUE(m.empty()); +} + +TEST(LinkedHashMapTest, Erase) { + QuicheLinkedHashMap m; + ASSERT_EQ(0u, m.size()); + EXPECT_EQ(0u, m.erase(2)); // Nothing to erase yet + + m.insert(std::make_pair(2, 12)); + ASSERT_EQ(1u, m.size()); + EXPECT_EQ(1u, m.erase(2)); + EXPECT_EQ(0u, m.size()); + + EXPECT_EQ(0u, m.erase(2)); // Make sure nothing bad happens if we repeat. + EXPECT_EQ(0u, m.size()); +} + +TEST(LinkedHashMapTest, Erase2) { + QuicheLinkedHashMap m; + ASSERT_EQ(0u, m.size()); + EXPECT_EQ(0u, m.erase(2)); // Nothing to erase yet + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + m.insert(std::make_pair(4, 14)); + ASSERT_EQ(4u, m.size()); + + // Erase middle two + EXPECT_EQ(1u, m.erase(1)); + EXPECT_EQ(1u, m.erase(3)); + + EXPECT_EQ(2u, m.size()); + + // Make sure we can still iterate over everything that's left. + QuicheLinkedHashMap::iterator it = m.begin(); + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(12, it->second); + ++it; + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(14, it->second); + ++it; + ASSERT_TRUE(it == m.end()); + + EXPECT_EQ(0u, m.erase(1)); // Make sure nothing bad happens if we repeat. + ASSERT_EQ(2u, m.size()); + + EXPECT_EQ(1u, m.erase(2)); + EXPECT_EQ(1u, m.erase(4)); + ASSERT_EQ(0u, m.size()); + + EXPECT_EQ(0u, m.erase(1)); // Make sure nothing bad happens if we repeat. + ASSERT_EQ(0u, m.size()); +} + +// Test that erase(iter,iter) and erase(iter) compile and work. +TEST(LinkedHashMapTest, Erase3) { + QuicheLinkedHashMap m; + + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(3, 13)); + m.insert(std::make_pair(4, 14)); + + // Erase middle two + QuicheLinkedHashMap::iterator it2 = m.find(2); + QuicheLinkedHashMap::iterator it4 = m.find(4); + EXPECT_EQ(m.erase(it2, it4), m.find(4)); + EXPECT_EQ(2u, m.size()); + + // Make sure we can still iterate over everything that's left. + QuicheLinkedHashMap::iterator it = m.begin(); + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(11, it->second); + ++it; + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(14, it->second); + ++it; + ASSERT_TRUE(it == m.end()); + + // Erase first one using an iterator. + EXPECT_EQ(m.erase(m.begin()), m.find(4)); + + // Only the last element should be left. + it = m.begin(); + ASSERT_TRUE(it != m.end()); + EXPECT_EQ(14, it->second); + ++it; + ASSERT_TRUE(it == m.end()); +} + +TEST(LinkedHashMapTest, Insertion) { + QuicheLinkedHashMap m; + ASSERT_EQ(0u, m.size()); + std::pair::iterator, bool> result; + + result = m.insert(std::make_pair(2, 12)); + ASSERT_EQ(1u, m.size()); + EXPECT_TRUE(result.second); + EXPECT_EQ(2, result.first->first); + EXPECT_EQ(12, result.first->second); + + result = m.insert(std::make_pair(1, 11)); + ASSERT_EQ(2u, m.size()); + EXPECT_TRUE(result.second); + EXPECT_EQ(1, result.first->first); + EXPECT_EQ(11, result.first->second); + + result = m.insert(std::make_pair(3, 13)); + QuicheLinkedHashMap::iterator result_iterator = result.first; + ASSERT_EQ(3u, m.size()); + EXPECT_TRUE(result.second); + EXPECT_EQ(3, result.first->first); + EXPECT_EQ(13, result.first->second); + + result = m.insert(std::make_pair(3, 13)); + EXPECT_EQ(3u, m.size()); + EXPECT_FALSE(result.second) << "No insertion should have occurred."; + EXPECT_TRUE(result_iterator == result.first) + << "Duplicate insertion should have given us the original iterator."; +} + +static std::pair Pair(int i, int j) { + return {i, j}; +} + +// Test front accessors. +TEST(LinkedHashMapTest, Front) { + QuicheLinkedHashMap m; + + m.insert(std::make_pair(2, 12)); + m.insert(std::make_pair(1, 11)); + m.insert(std::make_pair(3, 13)); + + EXPECT_EQ(3u, m.size()); + EXPECT_EQ(Pair(2, 12), m.front()); + m.pop_front(); + EXPECT_EQ(2u, m.size()); + EXPECT_EQ(Pair(1, 11), m.front()); + m.pop_front(); + EXPECT_EQ(1u, m.size()); + EXPECT_EQ(Pair(3, 13), m.front()); + m.pop_front(); + EXPECT_TRUE(m.empty()); +} + +TEST(LinkedHashMapTest, Find) { + QuicheLinkedHashMap m; + + EXPECT_TRUE(m.end() == m.find(1)) + << "We shouldn't find anything in an empty map."; + + m.insert(std::make_pair(2, 12)); + EXPECT_TRUE(m.end() == m.find(1)) + << "We shouldn't find an element that doesn't exist in the map."; + + std::pair::iterator, bool> result = + m.insert(std::make_pair(1, 11)); + ASSERT_TRUE(result.second); + ASSERT_TRUE(m.end() != result.first); + EXPECT_TRUE(result.first == m.find(1)) + << "We should have found an element we know exists in the map."; + EXPECT_EQ(11, result.first->second); + + // Check that a follow-up insertion doesn't affect our original + m.insert(std::make_pair(3, 13)); + QuicheLinkedHashMap::iterator it = m.find(1); + ASSERT_TRUE(m.end() != it); + EXPECT_EQ(11, it->second); + + m.clear(); + EXPECT_TRUE(m.end() == m.find(1)) + << "We shouldn't find anything in a map that we've cleared."; +} + +TEST(LinkedHashMapTest, Contains) { + QuicheLinkedHashMap m; + + EXPECT_FALSE(m.contains(1)) << "An empty map shouldn't contain anything."; + + m.insert(std::make_pair(2, 12)); + EXPECT_FALSE(m.contains(1)) + << "The map shouldn't contain an element that doesn't exist."; + + m.insert(std::make_pair(1, 11)); + EXPECT_TRUE(m.contains(1)) + << "The map should contain an element that we know exists."; + + m.clear(); + EXPECT_FALSE(m.contains(1)) + << "A map that we've cleared shouldn't contain anything."; +} + +TEST(LinkedHashMapTest, Swap) { + QuicheLinkedHashMap m1; + QuicheLinkedHashMap m2; + m1.insert(std::make_pair(1, 1)); + m1.insert(std::make_pair(2, 2)); + m2.insert(std::make_pair(3, 3)); + ASSERT_EQ(2u, m1.size()); + ASSERT_EQ(1u, m2.size()); + m1.swap(m2); + ASSERT_EQ(1u, m1.size()); + ASSERT_EQ(2u, m2.size()); +} + +TEST(LinkedHashMapTest, CustomHashAndEquality) { + struct CustomIntHash { + size_t operator()(int x) const { return x; } + }; + QuicheLinkedHashMap m; + m.insert(std::make_pair(1, 1)); + EXPECT_TRUE(m.contains(1)); + EXPECT_EQ(1, m[1]); +} + +} // namespace test +} // namespace quiche diff --git a/gquiche/common/quiche_text_utils.cc b/gquiche/common/quiche_text_utils.cc new file mode 100644 index 00000000..5cb7cbbe --- /dev/null +++ b/gquiche/common/quiche_text_utils.cc @@ -0,0 +1,77 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/quiche_text_utils.h" + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" + +namespace quiche { + +// static +void QuicheTextUtils::Base64Encode(const uint8_t* data, + size_t data_len, + std::string* output) { + absl::Base64Escape(std::string(reinterpret_cast(data), data_len), + output); + // Remove padding. + size_t len = output->size(); + if (len >= 2) { + if ((*output)[len - 1] == '=') { + len--; + if ((*output)[len - 1] == '=') { + len--; + } + output->resize(len); + } + } +} + +// static +absl::optional QuicheTextUtils::Base64Decode( + absl::string_view input) { + std::string output; + if (!absl::Base64Unescape(input, &output)) { + return absl::nullopt; + } + return output; +} + +// static +std::string QuicheTextUtils::HexDump(absl::string_view binary_data) { + const int kBytesPerLine = 16; // Maximum bytes dumped per line. + int offset = 0; + const char* p = binary_data.data(); + int bytes_remaining = binary_data.size(); + std::string output; + while (bytes_remaining > 0) { + const int line_bytes = std::min(bytes_remaining, kBytesPerLine); + absl::StrAppendFormat(&output, "0x%04x: ", offset); + for (int i = 0; i < kBytesPerLine; ++i) { + if (i < line_bytes) { + absl::StrAppendFormat(&output, "%02x", + static_cast(p[i])); + } else { + absl::StrAppend(&output, " "); + } + if (i % 2) { + absl::StrAppend(&output, " "); + } + } + absl::StrAppend(&output, " "); + for (int i = 0; i < line_bytes; ++i) { + // Replace non-printable characters and 0x20 (space) with '.' + output += absl::ascii_isgraph(p[i]) ? p[i] : '.'; + } + + bytes_remaining -= line_bytes; + offset += line_bytes; + p += line_bytes; + absl::StrAppend(&output, "\n"); + } + return output; +} + +} // namespace quiche diff --git a/gquiche/common/quiche_text_utils.h b/gquiche/common/quiche_text_utils.h new file mode 100644 index 00000000..6296ef3f --- /dev/null +++ b/gquiche/common/quiche_text_utils.h @@ -0,0 +1,76 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_COMMON_QUICHE_TEXT_UTILS_H_ +#define QUICHE_COMMON_QUICHE_TEXT_UTILS_H_ + +#include + +#include "absl/hash/hash.h" +#include "absl/strings/ascii.h" +#include "absl/strings/escaping.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/common/platform/api/quiche_export.h" + +namespace quiche { + +struct QUICHE_EXPORT_PRIVATE StringPieceCaseHash { + size_t operator()(absl::string_view data) const { + std::string lower = absl::AsciiStrToLower(data); + absl::Hash hasher; + return hasher(lower); + } +}; + +struct QUICHE_EXPORT_PRIVATE StringPieceCaseEqual { + bool operator()(absl::string_view piece1, absl::string_view piece2) const { + return absl::EqualsIgnoreCase(piece1, piece2); + } +}; + +// Various utilities for manipulating text. +class QUICHE_EXPORT_PRIVATE QuicheTextUtils { + public: + // Returns a new string in which |data| has been converted to lower case. + static std::string ToLower(absl::string_view data) { + return absl::AsciiStrToLower(data); + } + + // Removes leading and trailing whitespace from |data|. + static void RemoveLeadingAndTrailingWhitespace(absl::string_view* data) { + *data = absl::StripAsciiWhitespace(*data); + } + + // Base64 encodes with no padding |data_len| bytes of |data| into |output|. + static void Base64Encode(const uint8_t* data, + size_t data_len, + std::string* output); + + // Decodes a base64-encoded |input|. Returns nullopt when the input is + // invalid. + static absl::optional Base64Decode(absl::string_view input); + + // Returns a string containing hex and ASCII representations of |binary|, + // side-by-side in the style of hexdump. Non-printable characters will be + // printed as '.' in the ASCII output. + // For example, given the input "Hello, QUIC!\01\02\03\04", returns: + // "0x0000: 4865 6c6c 6f2c 2051 5549 4321 0102 0304 Hello,.QUIC!...." + static std::string HexDump(absl::string_view binary_data); + + // Returns true if |data| contains any uppercase characters. + static bool ContainsUpperCase(absl::string_view data) { + return std::any_of(data.begin(), data.end(), absl::ascii_isupper); + } + + // Returns true if |data| contains only decimal digits. + static bool IsAllDigits(absl::string_view data) { + return std::all_of(data.begin(), data.end(), absl::ascii_isdigit); + } +}; + +} // namespace quiche + +#endif // QUICHE_COMMON_QUICHE_TEXT_UTILS_H_ diff --git a/gquiche/common/quiche_text_utils_test.cc b/gquiche/common/quiche_text_utils_test.cc new file mode 100644 index 00000000..baae3d72 --- /dev/null +++ b/gquiche/common/quiche_text_utils_test.cc @@ -0,0 +1,92 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/common/quiche_text_utils.h" + +#include + +#include "absl/strings/escaping.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace quiche { +namespace test { + +TEST(QuicheTextUtilsTest, ToLower) { + EXPECT_EQ("lower", quiche::QuicheTextUtils::ToLower("LOWER")); + EXPECT_EQ("lower", quiche::QuicheTextUtils::ToLower("lower")); + EXPECT_EQ("lower", quiche::QuicheTextUtils::ToLower("lOwEr")); + EXPECT_EQ("123", quiche::QuicheTextUtils::ToLower("123")); + EXPECT_EQ("", quiche::QuicheTextUtils::ToLower("")); +} + +TEST(QuicheTextUtilsTest, RemoveLeadingAndTrailingWhitespace) { + std::string input; + + for (auto* input : {"text", " text", " text", "text ", "text ", " text ", + " text ", "\r\n\ttext", "text\n\r\t"}) { + absl::string_view piece(input); + quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&piece); + EXPECT_EQ("text", piece); + } +} + +TEST(QuicheTextUtilsTest, HexDump) { + // Verify output for empty input. + EXPECT_EQ("", quiche::QuicheTextUtils::HexDump(absl::HexStringToBytes(""))); + // Verify output of the HexDump method is as expected. + char packet[] = { + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x51, 0x55, 0x49, 0x43, 0x21, + 0x20, 0x54, 0x68, 0x69, 0x73, 0x20, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, + 0x20, 0x73, 0x68, 0x6f, 0x75, 0x6c, 0x64, 0x20, 0x62, 0x65, 0x20, 0x6c, + 0x6f, 0x6e, 0x67, 0x20, 0x65, 0x6e, 0x6f, 0x75, 0x67, 0x68, 0x20, 0x74, + 0x6f, 0x20, 0x73, 0x70, 0x61, 0x6e, 0x20, 0x6d, 0x75, 0x6c, 0x74, 0x69, + 0x70, 0x6c, 0x65, 0x20, 0x6c, 0x69, 0x6e, 0x65, 0x73, 0x20, 0x6f, 0x66, + 0x20, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x2e, 0x01, 0x02, 0x03, 0x00, + }; + EXPECT_EQ( + quiche::QuicheTextUtils::HexDump(packet), + "0x0000: 4865 6c6c 6f2c 2051 5549 4321 2054 6869 Hello,.QUIC!.Thi\n" + "0x0010: 7320 7374 7269 6e67 2073 686f 756c 6420 s.string.should.\n" + "0x0020: 6265 206c 6f6e 6720 656e 6f75 6768 2074 be.long.enough.t\n" + "0x0030: 6f20 7370 616e 206d 756c 7469 706c 6520 o.span.multiple.\n" + "0x0040: 6c69 6e65 7320 6f66 206f 7574 7075 742e lines.of.output.\n" + "0x0050: 0102 03 ...\n"); + // Verify that 0x21 and 0x7e are printable, 0x20 and 0x7f are not. + EXPECT_EQ( + "0x0000: 2021 7e7f .!~.\n", + quiche::QuicheTextUtils::HexDump(absl::HexStringToBytes("20217e7f"))); + // Verify that values above numeric_limits::max() are formatted + // properly on platforms where char is unsigned. + EXPECT_EQ("0x0000: 90aa ff ...\n", + quiche::QuicheTextUtils::HexDump(absl::HexStringToBytes("90aaff"))); +} + +TEST(QuicheTextUtilsTest, Base64Encode) { + std::string output; + std::string input = "Hello"; + quiche::QuicheTextUtils::Base64Encode( + reinterpret_cast(input.data()), input.length(), &output); + EXPECT_EQ("SGVsbG8", output); + + input = + "Hello, QUIC! This string should be long enough to span" + "multiple lines of output\n"; + quiche::QuicheTextUtils::Base64Encode( + reinterpret_cast(input.data()), input.length(), &output); + EXPECT_EQ( + "SGVsbG8sIFFVSUMhIFRoaXMgc3RyaW5nIHNob3VsZCBiZSBsb25n" + "IGVub3VnaCB0byBzcGFubXVsdGlwbGUgbGluZXMgb2Ygb3V0cHV0Cg", + output); +} + +TEST(QuicheTextUtilsTest, ContainsUpperCase) { + EXPECT_FALSE(quiche::QuicheTextUtils::ContainsUpperCase("abc")); + EXPECT_FALSE(quiche::QuicheTextUtils::ContainsUpperCase("")); + EXPECT_FALSE(quiche::QuicheTextUtils::ContainsUpperCase("123")); + EXPECT_TRUE(quiche::QuicheTextUtils::ContainsUpperCase("ABC")); + EXPECT_TRUE(quiche::QuicheTextUtils::ContainsUpperCase("aBc")); +} + +} // namespace test +} // namespace quiche diff --git a/gquiche/epoll_server/fake_simple_epoll_server.cc b/gquiche/epoll_server/fake_simple_epoll_server.cc index 1e8b0b18..05e76b96 100644 --- a/gquiche/epoll_server/fake_simple_epoll_server.cc +++ b/gquiche/epoll_server/fake_simple_epoll_server.cc @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "epoll_server/fake_simple_epoll_server.h" +#include "gquiche/epoll_server/fake_simple_epoll_server.h" namespace epoll_server { namespace test { diff --git a/gquiche/epoll_server/fake_simple_epoll_server.h b/gquiche/epoll_server/fake_simple_epoll_server.h index 3c677862..8092b0e2 100644 --- a/gquiche/epoll_server/fake_simple_epoll_server.h +++ b/gquiche/epoll_server/fake_simple_epoll_server.h @@ -10,8 +10,8 @@ #include -#include "epoll_server/platform/api/epoll_export.h" -#include "epoll_server/simple_epoll_server.h" +#include "gquiche/epoll_server/platform/api/epoll_export.h" +#include "gquiche/epoll_server/simple_epoll_server.h" namespace epoll_server { namespace test { diff --git a/gquiche/epoll_server/simple_epoll_server_test.cc b/gquiche/epoll_server/simple_epoll_server_test.cc index efd47998..963a028d 100644 --- a/gquiche/epoll_server/simple_epoll_server_test.cc +++ b/gquiche/epoll_server/simple_epoll_server_test.cc @@ -5,7 +5,7 @@ // Epoll tests which determine that the right things happen in the right order. // Also lots of testing of individual functions. -#include "epoll_server/simple_epoll_server.h" +#include "gquiche/epoll_server/simple_epoll_server.h" #include #include @@ -24,13 +24,13 @@ #include #include -#include "epoll_server/fake_simple_epoll_server.h" -#include "epoll_server/platform/api/epoll_address_test_utils.h" -#include "epoll_server/platform/api/epoll_expect_bug.h" -#include "epoll_server/platform/api/epoll_ptr_util.h" -#include "epoll_server/platform/api/epoll_test.h" -#include "epoll_server/platform/api/epoll_thread.h" -#include "epoll_server/platform/api/epoll_time.h" +#include "gquiche/epoll_server/fake_simple_epoll_server.h" +#include "gquiche/epoll_server/platform/api/epoll_address_test_utils.h" +#include "gquiche/epoll_server/platform/api/epoll_expect_bug.h" +#include "gquiche/epoll_server/platform/api/epoll_ptr_util.h" +#include "gquiche/epoll_server/platform/api/epoll_test.h" +#include "gquiche/epoll_server/platform/api/epoll_thread.h" +#include "gquiche/epoll_server/platform/api/epoll_time.h" namespace epoll_server { diff --git a/gquiche/http2/adapter/adapter_impl_comparison_test.cc b/gquiche/http2/adapter/adapter_impl_comparison_test.cc new file mode 100644 index 00000000..41702128 --- /dev/null +++ b/gquiche/http2/adapter/adapter_impl_comparison_test.cc @@ -0,0 +1,83 @@ +#include "gquiche/http2/adapter/recording_http2_visitor.h" + +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/nghttp2_adapter.h" +#include "gquiche/http2/adapter/oghttp2_adapter.h" +#include "gquiche/http2/adapter/test_frame_sequence.h" +#include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +TEST(AdapterImplComparisonTest, ClientHandlesFrames) { + RecordingHttp2Visitor nghttp2_visitor; + std::unique_ptr nghttp2_adapter = + NgHttp2Adapter::CreateClientAdapter(nghttp2_visitor); + + RecordingHttp2Visitor oghttp2_visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + std::unique_ptr oghttp2_adapter = + OgHttp2Adapter::Create(oghttp2_visitor, options); + + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + + nghttp2_adapter->ProcessBytes(initial_frames); + oghttp2_adapter->ProcessBytes(initial_frames); + + EXPECT_EQ(nghttp2_visitor.GetEventSequence(), + oghttp2_visitor.GetEventSequence()); + + // TODO(b/181586191): Consider consistent behavior for delivering events on + // non-existent streams between nghttp2_adapter and oghttp2_adapter. +} + +TEST(AdapterImplComparisonTest, ServerHandlesFrames) { + RecordingHttp2Visitor nghttp2_visitor; + std::unique_ptr nghttp2_adapter = + NgHttp2Adapter::CreateServerAdapter(nghttp2_visitor); + + RecordingHttp2Visitor oghttp2_visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + std::unique_ptr oghttp2_adapter = + OgHttp2Adapter::Create(oghttp2_visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + + nghttp2_adapter->ProcessBytes(frames); + oghttp2_adapter->ProcessBytes(frames); + + EXPECT_EQ(nghttp2_visitor.GetEventSequence(), + oghttp2_visitor.GetEventSequence()); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/callback_visitor.cc b/gquiche/http2/adapter/callback_visitor.cc new file mode 100644 index 00000000..5d0b76e8 --- /dev/null +++ b/gquiche/http2/adapter/callback_visitor.cc @@ -0,0 +1,448 @@ +#include "gquiche/http2/adapter/callback_visitor.h" + +#include "absl/strings/escaping.h" +#include "gquiche/http2/adapter/http2_util.h" +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/common/quiche_endian.h" + +// This visitor implementation needs visibility into the +// nghttp2_session_callbacks type. There's no public header, so we'll redefine +// the struct here. +struct nghttp2_session_callbacks { + nghttp2_send_callback send_callback; + nghttp2_recv_callback recv_callback; + nghttp2_on_frame_recv_callback on_frame_recv_callback; + nghttp2_on_invalid_frame_recv_callback on_invalid_frame_recv_callback; + nghttp2_on_data_chunk_recv_callback on_data_chunk_recv_callback; + nghttp2_before_frame_send_callback before_frame_send_callback; + nghttp2_on_frame_send_callback on_frame_send_callback; + nghttp2_on_frame_not_send_callback on_frame_not_send_callback; + nghttp2_on_stream_close_callback on_stream_close_callback; + nghttp2_on_begin_headers_callback on_begin_headers_callback; + nghttp2_on_header_callback on_header_callback; + nghttp2_on_header_callback2 on_header_callback2; + nghttp2_on_invalid_header_callback on_invalid_header_callback; + nghttp2_on_invalid_header_callback2 on_invalid_header_callback2; + nghttp2_select_padding_callback select_padding_callback; + nghttp2_data_source_read_length_callback read_length_callback; + nghttp2_on_begin_frame_callback on_begin_frame_callback; + nghttp2_send_data_callback send_data_callback; + nghttp2_pack_extension_callback pack_extension_callback; + nghttp2_unpack_extension_callback unpack_extension_callback; + nghttp2_on_extension_chunk_recv_callback on_extension_chunk_recv_callback; + nghttp2_error_callback error_callback; + nghttp2_error_callback2 error_callback2; +}; + +namespace http2 { +namespace adapter { + +CallbackVisitor::CallbackVisitor(Perspective perspective, + const nghttp2_session_callbacks& callbacks, + void* user_data) + : perspective_(perspective), + callbacks_(MakeCallbacksPtr(nullptr)), + user_data_(user_data) { + nghttp2_session_callbacks* c; + nghttp2_session_callbacks_new(&c); + *c = callbacks; + callbacks_ = MakeCallbacksPtr(c); +} + +int64_t CallbackVisitor::OnReadyToSend(absl::string_view serialized) { + if (!callbacks_->send_callback) { + return kSendError; + } + int64_t result = callbacks_->send_callback( + nullptr, ToUint8Ptr(serialized.data()), serialized.size(), 0, user_data_); + QUICHE_VLOG(1) << "CallbackVisitor::OnReadyToSend called with " + << serialized.size() << " bytes, returning " << result; + QUICHE_VLOG(2) << (perspective_ == Perspective::kClient ? "Client" : "Server") + << " sending: [" << absl::CEscape(serialized) << "]"; + if (result > 0) { + return result; + } else if (result == NGHTTP2_ERR_WOULDBLOCK) { + return kSendBlocked; + } else { + return kSendError; + } +} + +void CallbackVisitor::OnConnectionError(ConnectionError /*error*/) { + QUICHE_LOG(ERROR) << "OnConnectionError not implemented"; +} + +bool CallbackVisitor::OnFrameHeader(Http2StreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + QUICHE_VLOG(1) << "CallbackVisitor::OnFrameHeader(stream_id=" << stream_id + << ", type=" << int(type) << ", length=" << length + << ", flags=" << int(flags) << ")"; + if (static_cast(type) == FrameType::CONTINUATION) { + // Treat CONTINUATION as HEADERS + QUICHE_DCHECK_EQ(current_frame_.hd.stream_id, stream_id); + current_frame_.hd.length += length; + current_frame_.hd.flags |= flags; + QUICHE_DLOG_IF(ERROR, length == 0) << "Empty CONTINUATION!"; + // Still need to deliver the CONTINUATION to the begin frame callback. + nghttp2_frame_hd hd; + memset(&hd, 0, sizeof(hd)); + hd.stream_id = stream_id; + hd.length = length; + hd.type = type; + hd.flags = flags; + if (callbacks_->on_begin_frame_callback) { + const int result = + callbacks_->on_begin_frame_callback(nullptr, &hd, user_data_); + return result == 0; + } + return true; + } + // The general strategy is to clear |current_frame_| at the start of a new + // frame, accumulate frame information from the various callback events, then + // invoke the on_frame_recv_callback() with the accumulated frame data. + memset(¤t_frame_, 0, sizeof(current_frame_)); + current_frame_.hd.stream_id = stream_id; + current_frame_.hd.length = length; + current_frame_.hd.type = type; + current_frame_.hd.flags = flags; + if (callbacks_->on_begin_frame_callback) { + const int result = callbacks_->on_begin_frame_callback( + nullptr, ¤t_frame_.hd, user_data_); + return result == 0; + } + return true; +} + +void CallbackVisitor::OnSettingsStart() {} + +void CallbackVisitor::OnSetting(Http2Setting setting) { + settings_.push_back({setting.id, setting.value}); +} + +void CallbackVisitor::OnSettingsEnd() { + current_frame_.settings.niv = settings_.size(); + current_frame_.settings.iv = settings_.data(); + QUICHE_VLOG(1) << "OnSettingsEnd, received settings of size " + << current_frame_.settings.niv; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } + settings_.clear(); +} + +void CallbackVisitor::OnSettingsAck() { + // ACK is part of the flags, which were set in OnFrameHeader(). + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +bool CallbackVisitor::OnBeginHeadersForStream(Http2StreamId stream_id) { + auto it = GetStreamInfo(stream_id); + if (it->second->received_headers) { + // At least one headers frame has already been received. + QUICHE_VLOG(1) + << "Headers already received for stream " << stream_id + << ", these are trailers or headers following a 100 response"; + current_frame_.headers.cat = NGHTTP2_HCAT_HEADERS; + } else { + switch (perspective_) { + case Perspective::kClient: + QUICHE_VLOG(1) << "First headers at the client for stream " << stream_id + << "; these are response headers"; + current_frame_.headers.cat = NGHTTP2_HCAT_RESPONSE; + break; + case Perspective::kServer: + QUICHE_VLOG(1) << "First headers at the server for stream " << stream_id + << "; these are request headers"; + current_frame_.headers.cat = NGHTTP2_HCAT_REQUEST; + break; + } + } + it->second->received_headers = true; + if (callbacks_->on_begin_headers_callback) { + const int result = callbacks_->on_begin_headers_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +Http2VisitorInterface::OnHeaderResult CallbackVisitor::OnHeaderForStream( + Http2StreamId /*stream_id*/, absl::string_view name, + absl::string_view value) { + if (callbacks_->on_header_callback) { + const int result = callbacks_->on_header_callback( + nullptr, ¤t_frame_, ToUint8Ptr(name.data()), name.size(), + ToUint8Ptr(value.data()), value.size(), NGHTTP2_NV_FLAG_NONE, + user_data_); + if (result == 0) { + return HEADER_OK; + } else if (result == NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE) { + return HEADER_RST_STREAM; + } else { + // Assume NGHTTP2_ERR_CALLBACK_FAILURE. + return HEADER_CONNECTION_ERROR; + } + } + return HEADER_OK; +} + +bool CallbackVisitor::OnEndHeadersForStream(Http2StreamId /*stream_id*/) { + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnBeginDataForStream(Http2StreamId /*stream_id*/, + size_t payload_length) { + // TODO(b/181586191): Interpret padding, subtract padding from + // |remaining_data_|. + remaining_data_ = payload_length; + if (remaining_data_ == 0 && callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnDataForStream(Http2StreamId stream_id, + absl::string_view data) { + int result = 0; + if (callbacks_->on_data_chunk_recv_callback) { + result = callbacks_->on_data_chunk_recv_callback( + nullptr, current_frame_.hd.flags, stream_id, ToUint8Ptr(data.data()), + data.size(), user_data_); + } + remaining_data_ -= data.size(); + if (result == 0 && remaining_data_ == 0 && + callbacks_->on_frame_recv_callback) { + result = callbacks_->on_frame_recv_callback(nullptr, ¤t_frame_, + user_data_); + } + return result == 0; +} + +void CallbackVisitor::OnEndStream(Http2StreamId /*stream_id*/) {} + +void CallbackVisitor::OnRstStream(Http2StreamId /*stream_id*/, + Http2ErrorCode error_code) { + current_frame_.rst_stream.error_code = static_cast(error_code); + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + if (callbacks_->on_stream_close_callback) { + QUICHE_VLOG(1) << "OnCloseStream(stream_id: " << stream_id + << ", error_code: " << int(error_code) << ")"; + callbacks_->on_stream_close_callback( + nullptr, stream_id, static_cast(error_code), user_data_); + } +} + +void CallbackVisitor::OnPriorityForStream(Http2StreamId /*stream_id*/, + Http2StreamId parent_stream_id, + int weight, bool exclusive) { + current_frame_.priority.pri_spec.stream_id = parent_stream_id; + current_frame_.priority.pri_spec.weight = weight; + current_frame_.priority.pri_spec.exclusive = exclusive; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::OnPing(Http2PingId ping_id, bool /*is_ack*/) { + uint64_t network_order_opaque_data = + quiche::QuicheEndian::HostToNet64(ping_id); + std::memcpy(current_frame_.ping.opaque_data, &network_order_opaque_data, + sizeof(network_order_opaque_data)); + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::OnPushPromiseForStream( + Http2StreamId /*stream_id*/, Http2StreamId /*promised_stream_id*/) { + QUICHE_LOG(DFATAL) << "Not implemented"; +} + +bool CallbackVisitor::OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + current_frame_.goaway.last_stream_id = last_accepted_stream_id; + current_frame_.goaway.error_code = static_cast(error_code); + current_frame_.goaway.opaque_data = ToUint8Ptr(opaque_data.data()); + current_frame_.goaway.opaque_data_len = opaque_data.size(); + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + return result == 0; + } + return true; +} + +void CallbackVisitor::OnWindowUpdate(Http2StreamId /*stream_id*/, + int window_increment) { + current_frame_.window_update.window_size_increment = window_increment; + if (callbacks_->on_frame_recv_callback) { + const int result = callbacks_->on_frame_recv_callback( + nullptr, ¤t_frame_, user_data_); + QUICHE_DCHECK_EQ(0, result); + } +} + +void CallbackVisitor::PopulateFrame(nghttp2_frame& frame, uint8_t frame_type, + Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code, + bool sent_headers) { + frame.hd.type = frame_type; + frame.hd.stream_id = stream_id; + frame.hd.length = length; + frame.hd.flags = flags; + const FrameType frame_type_enum = static_cast(frame_type); + if (frame_type_enum == FrameType::HEADERS) { + if (sent_headers) { + frame.headers.cat = NGHTTP2_HCAT_HEADERS; + } else { + switch (perspective_) { + case Perspective::kClient: + QUICHE_VLOG(1) << "First headers sent by the client for stream " + << stream_id << "; these are request headers"; + frame.headers.cat = NGHTTP2_HCAT_REQUEST; + break; + case Perspective::kServer: + QUICHE_VLOG(1) << "First headers sent by the server for stream " + << stream_id << "; these are response headers"; + frame.headers.cat = NGHTTP2_HCAT_RESPONSE; + break; + } + } + } else if (frame_type_enum == FrameType::RST_STREAM) { + frame.rst_stream.error_code = error_code; + } else if (frame_type_enum == FrameType::GOAWAY) { + frame.goaway.error_code = error_code; + } +} + +int CallbackVisitor::OnBeforeFrameSent(uint8_t frame_type, + Http2StreamId stream_id, size_t length, + uint8_t flags) { + if (callbacks_->before_frame_send_callback) { + QUICHE_VLOG(1) << "OnBeforeFrameSent(stream_id=" << stream_id + << ", type=" << int(frame_type) << ", length=" << length + << ", flags=" << int(flags) << ")"; + nghttp2_frame frame; + auto it = GetStreamInfo(stream_id); + // The implementation of the before_frame_send_callback doesn't look at the + // error code, so for now it's populated with 0. + PopulateFrame(frame, frame_type, stream_id, length, flags, /*error_code=*/0, + it->second->before_sent_headers); + it->second->before_sent_headers = true; + return callbacks_->before_frame_send_callback(nullptr, &frame, user_data_); + } + return 0; +} + +int CallbackVisitor::OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags, + uint32_t error_code) { + if (callbacks_->on_frame_send_callback) { + QUICHE_VLOG(1) << "OnFrameSent(stream_id=" << stream_id + << ", type=" << int(frame_type) << ", length=" << length + << ", flags=" << int(flags) << ", error_code=" << error_code + << ")"; + nghttp2_frame frame; + auto it = GetStreamInfo(stream_id); + PopulateFrame(frame, frame_type, stream_id, length, flags, error_code, + it->second->sent_headers); + it->second->sent_headers = true; + return callbacks_->on_frame_send_callback(nullptr, &frame, user_data_); + } + return 0; +} + +bool CallbackVisitor::OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) { + QUICHE_VLOG(1) << "OnInvalidFrame(" << stream_id << ", " + << InvalidFrameErrorToString(error) << ")"; + QUICHE_DCHECK_EQ(stream_id, current_frame_.hd.stream_id); + if (callbacks_->on_invalid_frame_recv_callback) { + return 0 == + callbacks_->on_invalid_frame_recv_callback( + nullptr, ¤t_frame_, ToNgHttp2ErrorCode(error), user_data_); + } + return true; +} + +void CallbackVisitor::OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) { + QUICHE_VLOG(1) << "OnBeginMetadataForStream(stream_id=" << stream_id + << ", payload_length=" << payload_length << ")"; +} + +bool CallbackVisitor::OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) { + QUICHE_VLOG(1) << "OnMetadataForStream(stream_id=" << stream_id + << ", len=" << metadata.size() << ")"; + if (callbacks_->on_extension_chunk_recv_callback) { + int result = callbacks_->on_extension_chunk_recv_callback( + nullptr, ¤t_frame_.hd, ToUint8Ptr(metadata.data()), + metadata.size(), user_data_); + return result == 0; + } + return true; +} + +bool CallbackVisitor::OnMetadataEndForStream(Http2StreamId stream_id) { + QUICHE_LOG_IF(DFATAL, current_frame_.hd.flags != kMetadataEndFlag); + QUICHE_VLOG(1) << "OnMetadataEndForStream(stream_id=" << stream_id << ")"; + if (callbacks_->unpack_extension_callback) { + void* payload; + int result = callbacks_->unpack_extension_callback( + nullptr, &payload, ¤t_frame_.hd, user_data_); + if (result == 0 && callbacks_->on_frame_recv_callback) { + current_frame_.ext.payload = payload; + result = callbacks_->on_frame_recv_callback(nullptr, ¤t_frame_, + user_data_); + } + return (result == 0); + } + return true; +} + +void CallbackVisitor::OnErrorDebug(absl::string_view message) { + if (callbacks_->error_callback2) { + callbacks_->error_callback2(nullptr, -1, message.data(), message.size(), + user_data_); + } +} + +CallbackVisitor::StreamInfoMap::iterator CallbackVisitor::GetStreamInfo( + Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + auto p = stream_map_.insert({stream_id, absl::make_unique()}); + it = p.first; + } + return it; +} + +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/callback_visitor.h b/gquiche/http2/adapter/callback_visitor.h new file mode 100644 index 00000000..3a59675e --- /dev/null +++ b/gquiche/http2/adapter/callback_visitor.h @@ -0,0 +1,100 @@ +#ifndef QUICHE_HTTP2_ADAPTER_CALLBACK_VISITOR_H_ +#define QUICHE_HTTP2_ADAPTER_CALLBACK_VISITOR_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// This visitor implementation accepts a set of nghttp2 callbacks and a "user +// data" pointer, and invokes the callbacks according to HTTP/2 events received. +class QUICHE_EXPORT_PRIVATE CallbackVisitor : public Http2VisitorInterface { + public: + explicit CallbackVisitor(Perspective perspective, + const nghttp2_session_callbacks& callbacks, + void* user_data); + + int64_t OnReadyToSend(absl::string_view serialized) override; + void OnConnectionError(ConnectionError error) override; + bool OnFrameHeader(Http2StreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnSettingsStart() override; + void OnSetting(Http2Setting setting) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + bool OnBeginHeadersForStream(Http2StreamId stream_id) override; + OnHeaderResult OnHeaderForStream(Http2StreamId stream_id, + absl::string_view name, + absl::string_view value) override; + bool OnEndHeadersForStream(Http2StreamId stream_id) override; + bool OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnDataForStream(Http2StreamId stream_id, + absl::string_view data) override; + void OnEndStream(Http2StreamId stream_id) override; + void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) override; + void OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) override; + void OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, + bool exclusive) override; + void OnPing(Http2PingId ping_id, bool is_ack) override; + void OnPushPromiseForStream(Http2StreamId stream_id, + Http2StreamId promised_stream_id) override; + bool OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + void OnWindowUpdate(Http2StreamId stream_id, int window_increment) override; + int OnBeforeFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags) override; + int OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code) override; + bool OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) override; + void OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) override; + bool OnMetadataEndForStream(Http2StreamId stream_id) override; + void OnErrorDebug(absl::string_view message) override; + + private: + struct QUICHE_EXPORT_PRIVATE StreamInfo { + bool before_sent_headers = false; + bool sent_headers = false; + bool received_headers = false; + }; + + using StreamInfoMap = + absl::flat_hash_map>; + + void PopulateFrame(nghttp2_frame& frame, uint8_t frame_type, + Http2StreamId stream_id, size_t length, uint8_t flags, + uint32_t error_code, bool sent_headers); + // Creates the StreamInfoMap entry if it doesn't exist. + StreamInfoMap::iterator GetStreamInfo(Http2StreamId stream_id); + + Perspective perspective_; + nghttp2_session_callbacks_unique_ptr callbacks_; + void* user_data_; + + nghttp2_frame current_frame_; + std::vector settings_; + size_t remaining_data_ = 0; + + StreamInfoMap stream_map_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_CALLBACK_VISITOR_H_ diff --git a/gquiche/http2/adapter/callback_visitor_test.cc b/gquiche/http2/adapter/callback_visitor_test.cc new file mode 100644 index 00000000..e8c00078 --- /dev/null +++ b/gquiche/http2/adapter/callback_visitor_test.cc @@ -0,0 +1,297 @@ +#include "gquiche/http2/adapter/callback_visitor.h" + +#include "gquiche/http2/adapter/mock_nghttp2_callbacks.h" +#include "gquiche/http2/adapter/nghttp2_test_utils.h" +#include "gquiche/http2/adapter/test_utils.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +// Tests connection-level events. +TEST(ClientCallbackVisitorUnitTest, ConnectionFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // SETTINGS + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, SETTINGS, _))); + visitor.OnFrameHeader(0, 0, SETTINGS, 0); + + visitor.OnSettingsStart(); + EXPECT_CALL(callbacks, OnFrameRecv(IsSettings(testing::IsEmpty()))); + visitor.OnSettingsEnd(); + + // PING + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, PING, _))); + visitor.OnFrameHeader(0, 8, PING, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPing(42))); + visitor.OnPing(42, false); + + // WINDOW_UPDATE + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, WINDOW_UPDATE, _))); + visitor.OnFrameHeader(0, 4, WINDOW_UPDATE, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsWindowUpdate(1000))); + visitor.OnWindowUpdate(0, 1000); + + // PING ack + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(0, PING, NGHTTP2_FLAG_ACK))); + visitor.OnFrameHeader(0, 8, PING, 1); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPingAck(247))); + visitor.OnPing(247, true); + + // GOAWAY + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, GOAWAY, 0))); + visitor.OnFrameHeader(0, 19, GOAWAY, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsGoAway(5, NGHTTP2_ENHANCE_YOUR_CALM, + "calm down!!"))); + visitor.OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!"); +} + +TEST(ClientCallbackVisitorUnitTest, StreamFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, _))); + visitor.OnFrameHeader(1, 23, HEADERS, 4); + + EXPECT_CALL(callbacks, + OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, ":status", "200", _)); + visitor.OnHeaderForStream(1, ":status", "200"); + + EXPECT_CALL(callbacks, OnHeader(_, "server", "my-fake-server", _)); + visitor.OnHeaderForStream(1, "server", "my-fake-server"); + + EXPECT_CALL(callbacks, + OnHeader(_, "date", "Tue, 6 Apr 2021 12:54:01 GMT", _)); + visitor.OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT"); + + EXPECT_CALL(callbacks, OnHeader(_, "trailer", "x-server-status", _)); + visitor.OnHeaderForStream(1, "trailer", "x-server-status"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnEndHeadersForStream(1); + + // DATA for stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, DATA, 0))); + visitor.OnFrameHeader(1, 26, DATA, 0); + + visitor.OnBeginDataForStream(1, 26); + EXPECT_CALL(callbacks, OnDataChunkRecv(0, 1, "This is the response body.")); + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, 0))); + visitor.OnDataForStream(1, "This is the response body."); + + // Trailers for stream 1, with a different nghttp2 "category". + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, _))); + visitor.OnFrameHeader(1, 23, HEADERS, 4); + + EXPECT_CALL(callbacks, OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_HEADERS))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, "x-server-status", "OK", _)); + visitor.OnHeaderForStream(1, "x-server-status", "OK"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, _, NGHTTP2_HCAT_HEADERS))); + visitor.OnEndHeadersForStream(1); + + // RST_STREAM on stream 3 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(3, RST_STREAM, 0))); + visitor.OnFrameHeader(3, 4, RST_STREAM, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(3, NGHTTP2_INTERNAL_ERROR))); + visitor.OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_CALL(callbacks, OnStreamClose(3, NGHTTP2_INTERNAL_ERROR)); + visitor.OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR); + + // More stream close events + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(1, DATA, NGHTTP2_FLAG_END_STREAM))); + visitor.OnFrameHeader(1, 0, DATA, 1); + + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, NGHTTP2_FLAG_END_STREAM))); + visitor.OnBeginDataForStream(1, 0); + visitor.OnEndStream(1); + + EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR); + + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(5, RST_STREAM, _))); + visitor.OnFrameHeader(5, 4, RST_STREAM, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(5, NGHTTP2_REFUSED_STREAM))); + visitor.OnRstStream(5, Http2ErrorCode::REFUSED_STREAM); + + EXPECT_CALL(callbacks, OnStreamClose(5, NGHTTP2_REFUSED_STREAM)); + visitor.OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM); +} + +TEST(ClientCallbackVisitorUnitTest, HeadersWithContinuation) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kClient, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, HEADERS, 0x0))); + visitor.OnFrameHeader(1, 23, HEADERS, 0x0); + + EXPECT_CALL(callbacks, + OnBeginHeaders(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, ":status", "200", _)); + visitor.OnHeaderForStream(1, ":status", "200"); + + EXPECT_CALL(callbacks, OnHeader(_, "server", "my-fake-server", _)); + visitor.OnHeaderForStream(1, "server", "my-fake-server"); + + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(1, CONTINUATION, 0x4))); + visitor.OnFrameHeader(1, 23, CONTINUATION, 0x4); + + EXPECT_CALL(callbacks, + OnHeader(_, "date", "Tue, 6 Apr 2021 12:54:01 GMT", _)); + visitor.OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT"); + + EXPECT_CALL(callbacks, OnHeader(_, "trailer", "x-server-status", _)); + visitor.OnHeaderForStream(1, "trailer", "x-server-status"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, _, NGHTTP2_HCAT_RESPONSE))); + visitor.OnEndHeadersForStream(1); +} + +TEST(ServerCallbackVisitorUnitTest, ConnectionFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kServer, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // SETTINGS + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, SETTINGS, _))); + visitor.OnFrameHeader(0, 0, SETTINGS, 0); + + visitor.OnSettingsStart(); + EXPECT_CALL(callbacks, OnFrameRecv(IsSettings(testing::IsEmpty()))); + visitor.OnSettingsEnd(); + + // PING + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, PING, _))); + visitor.OnFrameHeader(0, 8, PING, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPing(42))); + visitor.OnPing(42, false); + + // WINDOW_UPDATE + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(0, WINDOW_UPDATE, _))); + visitor.OnFrameHeader(0, 4, WINDOW_UPDATE, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsWindowUpdate(1000))); + visitor.OnWindowUpdate(0, 1000); + + // PING ack + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(0, PING, NGHTTP2_FLAG_ACK))); + visitor.OnFrameHeader(0, 8, PING, 1); + + EXPECT_CALL(callbacks, OnFrameRecv(IsPingAck(247))); + visitor.OnPing(247, true); +} + +TEST(ServerCallbackVisitorUnitTest, StreamFrames) { + testing::StrictMock callbacks; + CallbackVisitor visitor(Perspective::kServer, + *MockNghttp2Callbacks::GetCallbacks(), &callbacks); + + testing::InSequence seq; + + // HEADERS on stream 1 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader( + 1, HEADERS, NGHTTP2_FLAG_END_HEADERS))); + visitor.OnFrameHeader(1, 23, HEADERS, 4); + + EXPECT_CALL(callbacks, OnBeginHeaders(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + visitor.OnBeginHeadersForStream(1); + + EXPECT_CALL(callbacks, OnHeader(_, ":method", "POST", _)); + visitor.OnHeaderForStream(1, ":method", "POST"); + + EXPECT_CALL(callbacks, OnHeader(_, ":path", "/example/path", _)); + visitor.OnHeaderForStream(1, ":path", "/example/path"); + + EXPECT_CALL(callbacks, OnHeader(_, ":scheme", "https", _)); + visitor.OnHeaderForStream(1, ":scheme", "https"); + + EXPECT_CALL(callbacks, OnHeader(_, ":authority", "example.com", _)); + visitor.OnHeaderForStream(1, ":authority", "example.com"); + + EXPECT_CALL(callbacks, OnHeader(_, "accept", "text/html", _)); + visitor.OnHeaderForStream(1, "accept", "text/html"); + + EXPECT_CALL(callbacks, OnFrameRecv(IsHeaders(1, NGHTTP2_FLAG_END_HEADERS, + NGHTTP2_HCAT_REQUEST))); + visitor.OnEndHeadersForStream(1); + + // DATA on stream 1 + EXPECT_CALL(callbacks, + OnBeginFrame(HasFrameHeader(1, DATA, NGHTTP2_FLAG_END_STREAM))); + visitor.OnFrameHeader(1, 25, DATA, NGHTTP2_FLAG_END_STREAM); + + visitor.OnBeginDataForStream(1, 25); + EXPECT_CALL(callbacks, OnDataChunkRecv(NGHTTP2_FLAG_END_STREAM, 1, + "This is the request body.")); + EXPECT_CALL(callbacks, OnFrameRecv(IsData(1, _, NGHTTP2_FLAG_END_STREAM))); + visitor.OnDataForStream(1, "This is the request body."); + visitor.OnEndStream(1); + + EXPECT_CALL(callbacks, OnStreamClose(1, NGHTTP2_NO_ERROR)); + visitor.OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR); + + // RST_STREAM on stream 3 + EXPECT_CALL(callbacks, OnBeginFrame(HasFrameHeader(3, RST_STREAM, 0))); + visitor.OnFrameHeader(3, 4, RST_STREAM, 0); + + EXPECT_CALL(callbacks, OnFrameRecv(IsRstStream(3, NGHTTP2_INTERNAL_ERROR))); + visitor.OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR); + + EXPECT_CALL(callbacks, OnStreamClose(3, NGHTTP2_INTERNAL_ERROR)); + visitor.OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/data_source.h b/gquiche/http2/adapter/data_source.h new file mode 100644 index 00000000..8acdbb4e --- /dev/null +++ b/gquiche/http2/adapter/data_source.h @@ -0,0 +1,56 @@ +#ifndef QUICHE_HTTP2_ADAPTER_DATA_SOURCE_H_ +#define QUICHE_HTTP2_ADAPTER_DATA_SOURCE_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "gquiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +// Represents a source of DATA frames for transmission to the peer. +class QUICHE_EXPORT_PRIVATE DataFrameSource { + public: + virtual ~DataFrameSource() {} + + enum : int64_t { kBlocked = 0, kError = -1 }; + + // Returns the number of bytes to send in the next DATA frame, and whether + // this frame indicates the end of the data. Returns {kBlocked, false} if + // blocked, {kError, false} on error. + virtual std::pair SelectPayloadLength(size_t max_length) = 0; + + // This method is called with a frame header and a payload length to send. The + // source should send or buffer the entire frame and return true, or return + // false without sending or buffering anything. + virtual bool Send(absl::string_view frame_header, size_t payload_length) = 0; + + // If true, the end of this data source indicates the end of the stream. + // Otherwise, this data will be followed by trailers. + virtual bool send_fin() const = 0; +}; + +// Represents a source of metadata frames for transmission to the peer. +class QUICHE_EXPORT_PRIVATE MetadataSource { + public: + virtual ~MetadataSource() {} + + // Returns the number of frames of at most |max_frame_size| required to + // serialize the metadata for this source. Only required by the nghttp2 + // implementation. + virtual size_t NumFrames(size_t max_frame_size) const = 0; + + // This method is called with a destination buffer and length. It should + // return the number of payload bytes copied to |dest|, or a negative integer + // to indicate an error, as well as a boolean indicating whether the metadata + // has been completely copied. + virtual std::pair Pack(uint8_t* dest, size_t dest_len) = 0; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_DATA_SOURCE_H_ diff --git a/gquiche/http2/adapter/event_forwarder.cc b/gquiche/http2/adapter/event_forwarder.cc new file mode 100644 index 00000000..b68d7800 --- /dev/null +++ b/gquiche/http2/adapter/event_forwarder.cc @@ -0,0 +1,181 @@ +#include "gquiche/http2/adapter/event_forwarder.h" + +namespace http2 { +namespace adapter { + +EventForwarder::EventForwarder(ForwardPredicate can_forward, + spdy::SpdyFramerVisitorInterface& receiver) + : can_forward_(std::move(can_forward)), receiver_(receiver) {} + +void EventForwarder::OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) { + if (can_forward_()) { + receiver_.OnError(error, std::move(detailed_error)); + } +} + +void EventForwarder::OnCommonHeader(spdy::SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + if (can_forward_()) { + receiver_.OnCommonHeader(stream_id, length, type, flags); + } +} + +void EventForwarder::OnDataFrameHeader(spdy::SpdyStreamId stream_id, + size_t length, bool fin) { + if (can_forward_()) { + receiver_.OnDataFrameHeader(stream_id, length, fin); + } +} + +void EventForwarder::OnStreamFrameData(spdy::SpdyStreamId stream_id, + const char* data, size_t len) { + if (can_forward_()) { + receiver_.OnStreamFrameData(stream_id, data, len); + } +} + +void EventForwarder::OnStreamEnd(spdy::SpdyStreamId stream_id) { + if (can_forward_()) { + receiver_.OnStreamEnd(stream_id); + } +} + +void EventForwarder::OnStreamPadLength(spdy::SpdyStreamId stream_id, + size_t value) { + if (can_forward_()) { + receiver_.OnStreamPadLength(stream_id, value); + } +} + +void EventForwarder::OnStreamPadding(spdy::SpdyStreamId stream_id, size_t len) { + if (can_forward_()) { + receiver_.OnStreamPadding(stream_id, len); + } +} + +spdy::SpdyHeadersHandlerInterface* EventForwarder::OnHeaderFrameStart( + spdy::SpdyStreamId stream_id) { + return receiver_.OnHeaderFrameStart(stream_id); +} + +void EventForwarder::OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) { + if (can_forward_()) { + receiver_.OnHeaderFrameEnd(stream_id); + } +} + +void EventForwarder::OnRstStream(spdy::SpdyStreamId stream_id, + spdy::SpdyErrorCode error_code) { + if (can_forward_()) { + receiver_.OnRstStream(stream_id, error_code); + } +} + +void EventForwarder::OnSettings() { + if (can_forward_()) { + receiver_.OnSettings(); + } +} + +void EventForwarder::OnSetting(spdy::SpdySettingsId id, uint32_t value) { + if (can_forward_()) { + receiver_.OnSetting(id, value); + } +} + +void EventForwarder::OnSettingsEnd() { + if (can_forward_()) { + receiver_.OnSettingsEnd(); + } +} + +void EventForwarder::OnSettingsAck() { + if (can_forward_()) { + receiver_.OnSettingsAck(); + } +} + +void EventForwarder::OnPing(spdy::SpdyPingId unique_id, bool is_ack) { + if (can_forward_()) { + receiver_.OnPing(unique_id, is_ack); + } +} + +void EventForwarder::OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, + spdy::SpdyErrorCode error_code) { + if (can_forward_()) { + receiver_.OnGoAway(last_accepted_stream_id, error_code); + } +} + +bool EventForwarder::OnGoAwayFrameData(const char* goaway_data, size_t len) { + if (can_forward_()) { + return receiver_.OnGoAwayFrameData(goaway_data, len); + } + return false; +} + +void EventForwarder::OnHeaders(spdy::SpdyStreamId stream_id, bool has_priority, + int weight, spdy::SpdyStreamId parent_stream_id, + bool exclusive, bool fin, bool end) { + if (can_forward_()) { + receiver_.OnHeaders(stream_id, has_priority, weight, parent_stream_id, + exclusive, fin, end); + } +} + +void EventForwarder::OnWindowUpdate(spdy::SpdyStreamId stream_id, + int delta_window_size) { + if (can_forward_()) { + receiver_.OnWindowUpdate(stream_id, delta_window_size); + } +} + +void EventForwarder::OnPushPromise(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId promised_stream_id, + bool end) { + if (can_forward_()) { + receiver_.OnPushPromise(stream_id, promised_stream_id, end); + } +} + +void EventForwarder::OnContinuation(spdy::SpdyStreamId stream_id, bool end) { + if (can_forward_()) { + receiver_.OnContinuation(stream_id, end); + } +} + +void EventForwarder::OnAltSvc( + spdy::SpdyStreamId stream_id, absl::string_view origin, + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector) { + if (can_forward_()) { + receiver_.OnAltSvc(stream_id, origin, altsvc_vector); + } +} + +void EventForwarder::OnPriority(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId parent_stream_id, int weight, + bool exclusive) { + if (can_forward_()) { + receiver_.OnPriority(stream_id, parent_stream_id, weight, exclusive); + } +} + +void EventForwarder::OnPriorityUpdate(spdy::SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) { + if (can_forward_()) { + receiver_.OnPriorityUpdate(prioritized_stream_id, priority_field_value); + } +} + +bool EventForwarder::OnUnknownFrame(spdy::SpdyStreamId stream_id, + uint8_t frame_type) { + if (can_forward_()) { + return receiver_.OnUnknownFrame(stream_id, frame_type); + } + return false; +} + +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/event_forwarder.h b/gquiche/http2/adapter/event_forwarder.h new file mode 100644 index 00000000..4e906f53 --- /dev/null +++ b/gquiche/http2/adapter/event_forwarder.h @@ -0,0 +1,76 @@ +#ifndef QUICHE_HTTP2_ADAPTER_EVENT_FORWARDER_H_ +#define QUICHE_HTTP2_ADAPTER_EVENT_FORWARDER_H_ + +#include + +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/spdy/core/http2_frame_decoder_adapter.h" + +namespace http2 { +namespace adapter { + +// Forwards events to a provided SpdyFramerVisitorInterface receiver if the +// provided predicate succeeds. Currently, OnHeaderFrameStart() is always +// forwarded regardless of the predicate. +// TODO(diannahu): Add a NoOpHeadersHandler if needed. +class QUICHE_EXPORT_PRIVATE EventForwarder + : public spdy::SpdyFramerVisitorInterface { + public: + // Whether the forwarder can forward events to the receiver. + using ForwardPredicate = std::function; + + EventForwarder(ForwardPredicate can_forward, + spdy::SpdyFramerVisitorInterface& receiver); + + void OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) override; + void OnCommonHeader(spdy::SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnDataFrameHeader(spdy::SpdyStreamId stream_id, size_t length, + bool fin) override; + void OnStreamFrameData(spdy::SpdyStreamId stream_id, const char* data, + size_t len) override; + void OnStreamEnd(spdy::SpdyStreamId stream_id) override; + void OnStreamPadLength(spdy::SpdyStreamId stream_id, size_t value) override; + void OnStreamPadding(spdy::SpdyStreamId stream_id, size_t len) override; + spdy::SpdyHeadersHandlerInterface* OnHeaderFrameStart( + spdy::SpdyStreamId stream_id) override; + void OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) override; + void OnRstStream(spdy::SpdyStreamId stream_id, + spdy::SpdyErrorCode error_code) override; + void OnSettings() override; + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + void OnPing(spdy::SpdyPingId unique_id, bool is_ack) override; + void OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, + spdy::SpdyErrorCode error_code) override; + bool OnGoAwayFrameData(const char* goaway_data, size_t len) override; + void OnHeaders(spdy::SpdyStreamId stream_id, bool has_priority, int weight, + spdy::SpdyStreamId parent_stream_id, bool exclusive, bool fin, + bool end) override; + void OnWindowUpdate(spdy::SpdyStreamId stream_id, + int delta_window_size) override; + void OnPushPromise(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId promised_stream_id, bool end) override; + void OnContinuation(spdy::SpdyStreamId stream_id, bool end) override; + void OnAltSvc(spdy::SpdyStreamId stream_id, absl::string_view origin, + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector& + altsvc_vector) override; + void OnPriority(spdy::SpdyStreamId stream_id, + spdy::SpdyStreamId parent_stream_id, int weight, + bool exclusive) override; + void OnPriorityUpdate(spdy::SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) override; + bool OnUnknownFrame(spdy::SpdyStreamId stream_id, + uint8_t frame_type) override; + + private: + ForwardPredicate can_forward_; + spdy::SpdyFramerVisitorInterface& receiver_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_EVENT_FORWARDER_H_ diff --git a/gquiche/http2/adapter/event_forwarder_test.cc b/gquiche/http2/adapter/event_forwarder_test.cc new file mode 100644 index 00000000..1a0d6050 --- /dev/null +++ b/gquiche/http2/adapter/event_forwarder_test.cc @@ -0,0 +1,220 @@ +#include "gquiche/http2/adapter/event_forwarder.h" + +#include + +#include "absl/strings/string_view.h" +#include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/spdy/core/mock_spdy_framer_visitor.h" +#include "gquiche/spdy/core/spdy_protocol.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +constexpr absl::string_view some_data = "Here is some data for events"; +constexpr spdy::SpdyStreamId stream_id = 1; +constexpr spdy::SpdyErrorCode error_code = + spdy::SpdyErrorCode::ERROR_CODE_ENHANCE_YOUR_CALM; +constexpr size_t length = 42; + +TEST(EventForwarderTest, ForwardsEventsWithTruePredicate) { + spdy::test::MockSpdyFramerVisitor receiver; + receiver.DelegateHeaderHandling(); + EventForwarder event_forwarder([]() { return true; }, receiver); + + EXPECT_CALL( + receiver, + OnError(Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + std::string(some_data))); + event_forwarder.OnError( + Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + std::string(some_data)); + + EXPECT_CALL(receiver, + OnCommonHeader(stream_id, length, /*type=*/0x0, /*flags=*/0x1)); + event_forwarder.OnCommonHeader(stream_id, length, /*type=*/0x0, + /*flags=*/0x1); + + EXPECT_CALL(receiver, OnDataFrameHeader(stream_id, length, /*fin=*/true)); + event_forwarder.OnDataFrameHeader(stream_id, length, /*fin=*/true); + + EXPECT_CALL(receiver, + OnStreamFrameData(stream_id, some_data.data(), some_data.size())); + event_forwarder.OnStreamFrameData(stream_id, some_data.data(), + some_data.size()); + + EXPECT_CALL(receiver, OnStreamEnd(stream_id)); + event_forwarder.OnStreamEnd(stream_id); + + EXPECT_CALL(receiver, OnStreamPadLength(stream_id, length)); + event_forwarder.OnStreamPadLength(stream_id, length); + + EXPECT_CALL(receiver, OnStreamPadding(stream_id, length)); + event_forwarder.OnStreamPadding(stream_id, length); + + EXPECT_CALL(receiver, OnHeaderFrameStart(stream_id)); + spdy::SpdyHeadersHandlerInterface* handler = + event_forwarder.OnHeaderFrameStart(stream_id); + EXPECT_EQ(handler, receiver.ReturnTestHeadersHandler(stream_id)); + + EXPECT_CALL(receiver, OnHeaderFrameEnd(stream_id)); + event_forwarder.OnHeaderFrameEnd(stream_id); + + EXPECT_CALL(receiver, OnRstStream(stream_id, error_code)); + event_forwarder.OnRstStream(stream_id, error_code); + + EXPECT_CALL(receiver, OnSettings()); + event_forwarder.OnSettings(); + + EXPECT_CALL( + receiver, + OnSetting(spdy::SpdyKnownSettingsId::SETTINGS_MAX_CONCURRENT_STREAMS, + 100)); + event_forwarder.OnSetting( + spdy::SpdyKnownSettingsId::SETTINGS_MAX_CONCURRENT_STREAMS, 100); + + EXPECT_CALL(receiver, OnSettingsEnd()); + event_forwarder.OnSettingsEnd(); + + EXPECT_CALL(receiver, OnSettingsAck()); + event_forwarder.OnSettingsAck(); + + EXPECT_CALL(receiver, OnPing(/*unique_id=*/42, /*is_ack=*/false)); + event_forwarder.OnPing(/*unique_id=*/42, /*is_ack=*/false); + + EXPECT_CALL(receiver, OnGoAway(stream_id, error_code)); + event_forwarder.OnGoAway(stream_id, error_code); + + EXPECT_CALL(receiver, OnGoAwayFrameData(some_data.data(), some_data.size())); + event_forwarder.OnGoAwayFrameData(some_data.data(), some_data.size()); + + EXPECT_CALL( + receiver, + OnHeaders(stream_id, /*has_priority=*/false, /*weight=*/42, stream_id + 2, + /*exclusive=*/false, /*fin=*/true, /*end=*/true)); + event_forwarder.OnHeaders(stream_id, /*has_priority=*/false, /*weight=*/42, + stream_id + 2, /*exclusive=*/false, /*fin=*/true, + /*end=*/true); + + EXPECT_CALL(receiver, OnWindowUpdate(stream_id, /*delta_window_size=*/42)); + event_forwarder.OnWindowUpdate(stream_id, /*delta_window_size=*/42); + + EXPECT_CALL(receiver, OnPushPromise(stream_id, stream_id + 1, /*end=*/true)); + event_forwarder.OnPushPromise(stream_id, stream_id + 1, /*end=*/true); + + EXPECT_CALL(receiver, OnContinuation(stream_id, /*end=*/true)); + event_forwarder.OnContinuation(stream_id, /*end=*/true); + + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + EXPECT_CALL(receiver, OnAltSvc(stream_id, some_data, altsvc_vector)); + event_forwarder.OnAltSvc(stream_id, some_data, altsvc_vector); + + EXPECT_CALL(receiver, OnPriority(stream_id, stream_id + 2, /*weight=*/42, + /*exclusive=*/false)); + event_forwarder.OnPriority(stream_id, stream_id + 2, /*weight=*/42, + /*exclusive=*/false); + + EXPECT_CALL(receiver, OnPriorityUpdate(stream_id, some_data)); + event_forwarder.OnPriorityUpdate(stream_id, some_data); + + EXPECT_CALL(receiver, OnUnknownFrame(stream_id, /*frame_type=*/0x4D)); + event_forwarder.OnUnknownFrame(stream_id, /*frame_type=*/0x4D); +} + +TEST(EventForwarderTest, DoesNotForwardEventsWithFalsePredicate) { + spdy::test::MockSpdyFramerVisitor receiver; + receiver.DelegateHeaderHandling(); + EventForwarder event_forwarder([]() { return false; }, receiver); + + EXPECT_CALL(receiver, OnError).Times(0); + event_forwarder.OnError( + Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + std::string(some_data)); + + EXPECT_CALL(receiver, OnCommonHeader).Times(0); + event_forwarder.OnCommonHeader(stream_id, length, /*type=*/0x0, + /*flags=*/0x1); + + EXPECT_CALL(receiver, OnDataFrameHeader).Times(0); + event_forwarder.OnDataFrameHeader(stream_id, length, /*fin=*/true); + + EXPECT_CALL(receiver, OnStreamFrameData).Times(0); + event_forwarder.OnStreamFrameData(stream_id, some_data.data(), + some_data.size()); + + EXPECT_CALL(receiver, OnStreamEnd).Times(0); + event_forwarder.OnStreamEnd(stream_id); + + EXPECT_CALL(receiver, OnStreamPadLength).Times(0); + event_forwarder.OnStreamPadLength(stream_id, length); + + EXPECT_CALL(receiver, OnStreamPadding).Times(0); + event_forwarder.OnStreamPadding(stream_id, length); + + EXPECT_CALL(receiver, OnHeaderFrameStart(stream_id)); + spdy::SpdyHeadersHandlerInterface* handler = + event_forwarder.OnHeaderFrameStart(stream_id); + EXPECT_EQ(handler, receiver.ReturnTestHeadersHandler(stream_id)); + + EXPECT_CALL(receiver, OnHeaderFrameEnd).Times(0); + event_forwarder.OnHeaderFrameEnd(stream_id); + + EXPECT_CALL(receiver, OnRstStream).Times(0); + event_forwarder.OnRstStream(stream_id, error_code); + + EXPECT_CALL(receiver, OnSettings).Times(0); + event_forwarder.OnSettings(); + + EXPECT_CALL(receiver, OnSetting).Times(0); + event_forwarder.OnSetting( + spdy::SpdyKnownSettingsId::SETTINGS_MAX_CONCURRENT_STREAMS, 100); + + EXPECT_CALL(receiver, OnSettingsEnd).Times(0); + event_forwarder.OnSettingsEnd(); + + EXPECT_CALL(receiver, OnSettingsAck).Times(0); + event_forwarder.OnSettingsAck(); + + EXPECT_CALL(receiver, OnPing).Times(0); + event_forwarder.OnPing(/*unique_id=*/42, /*is_ack=*/false); + + EXPECT_CALL(receiver, OnGoAway).Times(0); + event_forwarder.OnGoAway(stream_id, error_code); + + EXPECT_CALL(receiver, OnGoAwayFrameData).Times(0); + event_forwarder.OnGoAwayFrameData(some_data.data(), some_data.size()); + + EXPECT_CALL(receiver, OnHeaders).Times(0); + event_forwarder.OnHeaders(stream_id, /*has_priority=*/false, /*weight=*/42, + stream_id + 2, /*exclusive=*/false, /*fin=*/true, + /*end=*/true); + + EXPECT_CALL(receiver, OnWindowUpdate).Times(0); + event_forwarder.OnWindowUpdate(stream_id, /*delta_window_size=*/42); + + EXPECT_CALL(receiver, OnPushPromise).Times(0); + event_forwarder.OnPushPromise(stream_id, stream_id + 1, /*end=*/true); + + EXPECT_CALL(receiver, OnContinuation).Times(0); + event_forwarder.OnContinuation(stream_id, /*end=*/true); + + EXPECT_CALL(receiver, OnAltSvc).Times(0); + const spdy::SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; + event_forwarder.OnAltSvc(stream_id, some_data, altsvc_vector); + + EXPECT_CALL(receiver, OnPriority).Times(0); + event_forwarder.OnPriority(stream_id, stream_id + 2, /*weight=*/42, + /*exclusive=*/false); + + EXPECT_CALL(receiver, OnPriorityUpdate).Times(0); + event_forwarder.OnPriorityUpdate(stream_id, some_data); + + EXPECT_CALL(receiver, OnUnknownFrame).Times(0); + event_forwarder.OnUnknownFrame(stream_id, /*frame_type=*/0x4D); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/header_validator.cc b/gquiche/http2/adapter/header_validator.cc new file mode 100644 index 00000000..3d45754e --- /dev/null +++ b/gquiche/http2/adapter/header_validator.cc @@ -0,0 +1,104 @@ +#include "gquiche/http2/adapter/header_validator.h" + +#include "absl/strings/escaping.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace http2 { +namespace adapter { + +namespace { + +const absl::string_view kHttp2HeaderNameAllowedChars = + "!#$%&\'*+-.0123456789" + "^_`abcdefghijklmnopqrstuvwxyz|~"; + +const absl::string_view kHttp2HeaderValueAllowedChars = + "\t " + "!\"#$%&'()*+,-./" + "0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`" + "abcdefghijklmnopqrstuvwxyz{|}~"; + +const absl::string_view kHttp2StatusValueAllowedChars = "0123456789"; + +// TODO(birenroy): Support websocket requests, which contain an extra +// `:protocol` pseudo-header. +bool ValidateRequestHeaders(const std::vector& pseudo_headers) { + static const std::vector* kRequiredHeaders = + new std::vector( + {":authority", ":method", ":path", ":scheme"}); + return pseudo_headers == *kRequiredHeaders; +} + +bool ValidateResponseHeaders(const std::vector& pseudo_headers) { + static const std::vector* kRequiredHeaders = + new std::vector({":status"}); + return pseudo_headers == *kRequiredHeaders; +} + +bool ValidateResponseTrailers(const std::vector& pseudo_headers) { + return pseudo_headers.empty(); +} + +} // namespace + +void HeaderValidator::StartHeaderBlock() { + pseudo_headers_.clear(); + status_.clear(); +} + +HeaderValidator::HeaderStatus HeaderValidator::ValidateSingleHeader( + absl::string_view key, absl::string_view value) { + if (key.empty()) { + return HEADER_NAME_EMPTY; + } + const absl::string_view validated_key = key[0] == ':' ? key.substr(1) : key; + if (validated_key.find_first_not_of(kHttp2HeaderNameAllowedChars) != + absl::string_view::npos) { + QUICHE_VLOG(2) << "invalid chars in header name: [" + << absl::CEscape(validated_key) << "]"; + return HEADER_NAME_INVALID_CHAR; + } + if (value.find_first_not_of(kHttp2HeaderValueAllowedChars) != + absl::string_view::npos) { + QUICHE_VLOG(2) << "invalid chars in header value: [" << absl::CEscape(value) + << "]"; + return HEADER_VALUE_INVALID_CHAR; + } + if (key[0] == ':') { + if (key == ":status") { + if (value.size() != 3 || + value.find_first_not_of(kHttp2StatusValueAllowedChars) != + absl::string_view::npos) { + QUICHE_VLOG(2) << "malformed status value: [" << absl::CEscape(value) + << "]"; + return HEADER_VALUE_INVALID_CHAR; + } + if (value == "101") { + // Switching protocols is not allowed on a HTTP/2 stream. + return HEADER_VALUE_INVALID_STATUS; + } + status_ = std::string(value); + } + pseudo_headers_.push_back(std::string(key)); + } + return HEADER_OK; +} + +// Returns true if all required pseudoheaders and no extra pseudoheaders are +// present for the given header type. +bool HeaderValidator::FinishHeaderBlock(HeaderType type) { + std::sort(pseudo_headers_.begin(), pseudo_headers_.end()); + switch (type) { + case HeaderType::REQUEST: + return ValidateRequestHeaders(pseudo_headers_); + case HeaderType::RESPONSE_100: + case HeaderType::RESPONSE: + return ValidateResponseHeaders(pseudo_headers_); + case HeaderType::RESPONSE_TRAILER: + return ValidateResponseTrailers(pseudo_headers_); + } + return false; +} + +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/header_validator.h b/gquiche/http2/adapter/header_validator.h new file mode 100644 index 00000000..dc8b5df9 --- /dev/null +++ b/gquiche/http2/adapter/header_validator.h @@ -0,0 +1,50 @@ +#ifndef QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_H_ +#define QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "gquiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +enum class HeaderType : uint8_t { + REQUEST, + RESPONSE_100, + RESPONSE, + RESPONSE_TRAILER, +}; + +class QUICHE_EXPORT_PRIVATE HeaderValidator { + public: + HeaderValidator() {} + + void StartHeaderBlock(); + + enum HeaderStatus { + HEADER_OK, + HEADER_NAME_EMPTY, + HEADER_NAME_INVALID_CHAR, + HEADER_VALUE_INVALID_CHAR, + HEADER_VALUE_INVALID_STATUS, + }; + HeaderStatus ValidateSingleHeader(absl::string_view key, + absl::string_view value); + + // Returns true if all required pseudoheaders and no extra pseudoheaders are + // present for the given header type. + bool FinishHeaderBlock(HeaderType type); + + absl::string_view status_header() const { return status_; } + + private: + std::vector pseudo_headers_; + std::string status_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_HEADER_VALIDATOR_H_ diff --git a/gquiche/http2/adapter/header_validator_test.cc b/gquiche/http2/adapter/header_validator_test.cc new file mode 100644 index 00000000..64cf1b87 --- /dev/null +++ b/gquiche/http2/adapter/header_validator_test.cc @@ -0,0 +1,213 @@ +#include "gquiche/http2/adapter/header_validator.h" + +#include "absl/strings/str_cat.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +TEST(HeaderValidatorTest, HeaderNameEmpty) { + HeaderValidator v; + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader("", "value"); + EXPECT_EQ(HeaderValidator::HEADER_NAME_EMPTY, status); +} + +TEST(HeaderValidatorTest, HeaderValueEmpty) { + HeaderValidator v; + HeaderValidator::HeaderStatus status = v.ValidateSingleHeader("name", ""); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); +} + +TEST(HeaderValidatorTest, NameHasInvalidChar) { + HeaderValidator v; + for (const bool is_pseudo_header : {true, false}) { + // These characters should be allowed. (Not exhaustive.) + for (const char* c : {"!", "3", "a", "_", "|", "~"}) { + const std::string name = is_pseudo_header ? absl::StrCat(":met", c, "hod") + : absl::StrCat("na", c, "me"); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + // These should not. (Not exhaustive.) + for (const char* c : {"\\", "<", ";", "[", "=", " ", "\r", "\n", ",", "\"", + "\x1F", "\x91"}) { + const std::string name = is_pseudo_header ? absl::StrCat(":met", c, "hod") + : absl::StrCat("na", c, "me"); + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_NAME_INVALID_CHAR, status); + } + // Uppercase characters in header names should not be allowed. + const std::string uc_name = is_pseudo_header ? ":Method" : "Name"; + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader(uc_name, "value"); + EXPECT_EQ(HeaderValidator::HEADER_NAME_INVALID_CHAR, status); + } +} + +TEST(HeaderValidatorTest, ValueHasInvalidChar) { + HeaderValidator v; + // These characters should be allowed. (Not exhaustive.) + for (const char* c : + {"!", "3", "a", "_", "|", "~", "\\", "<", ";", "[", "=", "A", "\t"}) { + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", absl::StrCat("val", c, "ue")); + EXPECT_EQ(HeaderValidator::HEADER_OK, status); + } + // These should not. + for (const char* c : {"\r", "\n"}) { + HeaderValidator::HeaderStatus status = + v.ValidateSingleHeader("name", absl::StrCat("val", c, "ue")); + EXPECT_EQ(HeaderValidator::HEADER_VALUE_INVALID_CHAR, status); + } +} + +TEST(HeaderValidatorTest, StatusHasInvalidChar) { + HeaderValidator v; + + for (HeaderType type : {HeaderType::RESPONSE, HeaderType::RESPONSE_100}) { + // When `:status` has a non-digit value, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_VALUE_INVALID_CHAR, + v.ValidateSingleHeader(":status", "bar")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When `:status` is too short, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_VALUE_INVALID_CHAR, + v.ValidateSingleHeader(":status", "10")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When `:status` is too long, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_VALUE_INVALID_CHAR, + v.ValidateSingleHeader(":status", "9000")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When `:status` is just right, validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "400")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + } +} + +TEST(HeaderValidatorTest, RequestPseudoHeaders) { + HeaderValidator v; + const absl::string_view headers[] = {":authority", ":method", ":path", + ":scheme"}; + for (absl::string_view to_skip : headers) { + v.StartHeaderBlock(); + for (absl::string_view to_add : headers) { + if (to_add != to_skip) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add, "foo")); + } + } + // When any pseudo-header is missing, final validation will fail. + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + } + + // When all pseudo-headers are present, final validation will succeed. + v.StartHeaderBlock(); + for (absl::string_view to_add : headers) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add, "foo")); + } + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // When an extra pseudo-header is present, final validation will fail. + v.StartHeaderBlock(); + for (absl::string_view to_add : headers) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add, "foo")); + } + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":extra", "blah")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + + // When a required pseudo-header is repeated, final validation will fail. + for (absl::string_view to_repeat : headers) { + v.StartHeaderBlock(); + for (absl::string_view to_add : headers) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add, "foo")); + if (to_add == to_repeat) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add, "foo")); + } + } + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); + } +} + +TEST(HeaderValidatorTest, WebsocketPseudoHeaders) { + HeaderValidator v; + const absl::string_view headers[] = {":authority", ":method", ":path", + ":scheme"}; + v.StartHeaderBlock(); + for (absl::string_view to_add : headers) { + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(to_add, "foo")); + } + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":protocol", "websocket")); + // For now, `:protocol` is treated as an extra pseudo-header. + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::REQUEST)); +} + +TEST(HeaderValidatorTest, ResponsePseudoHeaders) { + HeaderValidator v; + + for (HeaderType type : {HeaderType::RESPONSE, HeaderType::RESPONSE_100}) { + // When `:status` is missing, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader("foo", "bar")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When all pseudo-headers are present, final validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_TRUE(v.FinishHeaderBlock(type)); + EXPECT_EQ("199", v.status_header()); + + // When `:status` is repeated, validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "299")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + + // When an extra pseudo-header is present, final validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "199")); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":extra", "blorp")); + EXPECT_FALSE(v.FinishHeaderBlock(type)); + } +} + +TEST(HeaderValidatorTest, ResponseTrailerPseudoHeaders) { + HeaderValidator v; + + // When no pseudo-headers are present, validation will succeed. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader("foo", "bar")); + EXPECT_TRUE(v.FinishHeaderBlock(HeaderType::RESPONSE_TRAILER)); + + // When any pseudo-header is present, final validation will fail. + v.StartHeaderBlock(); + EXPECT_EQ(HeaderValidator::HEADER_OK, + v.ValidateSingleHeader(":status", "200")); + EXPECT_EQ(HeaderValidator::HEADER_OK, v.ValidateSingleHeader("foo", "bar")); + EXPECT_FALSE(v.FinishHeaderBlock(HeaderType::RESPONSE_TRAILER)); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/http2_adapter.h b/gquiche/http2/adapter/http2_adapter.h index c55f1416..8154990b 100644 --- a/gquiche/http2/adapter/http2_adapter.h +++ b/gquiche/http2/adapter/http2_adapter.h @@ -1,12 +1,16 @@ #ifndef QUICHE_HTTP2_ADAPTER_HTTP2_ADAPTER_H_ #define QUICHE_HTTP2_ADAPTER_HTTP2_ADAPTER_H_ +#include + #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" +#include "gquiche/http2/adapter/data_source.h" #include "gquiche/http2/adapter/http2_protocol.h" #include "gquiche/http2/adapter/http2_session.h" #include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/common/platform/api/quiche_export.h" namespace http2 { namespace adapter { @@ -17,14 +21,21 @@ namespace adapter { // invokes corresponding callbacks on its passed-in Http2VisitorInterface. // Http2Adapter is a base class shared between client-side and server-side // implementations. -class Http2Adapter { +class QUICHE_EXPORT_PRIVATE Http2Adapter { public: Http2Adapter(const Http2Adapter&) = delete; Http2Adapter& operator=(const Http2Adapter&) = delete; + virtual ~Http2Adapter() {} + + virtual bool IsServerSession() const = 0; + + virtual bool want_read() const = 0; + virtual bool want_write() const = 0; + // Processes the incoming |bytes| as HTTP/2 and invokes callbacks on the // |visitor_| as appropriate. - virtual ssize_t ProcessBytes(absl::string_view bytes) = 0; + virtual int64_t ProcessBytes(absl::string_view bytes) = 0; // Submits the |settings| to be written to the peer, e.g., as part of the // HTTP/2 connection preface. @@ -36,17 +47,15 @@ class Http2Adapter { int weight, bool exclusive) = 0; - // Submits a PING on the connection. Note that nghttp2 automatically submits - // PING acks upon receiving non-ack PINGs from the peer, so callers only use - // this method to originate PINGs. See nghttp2_option_set_no_auto_ping_ack(). + // Submits a PING on the connection. virtual void SubmitPing(Http2PingId ping_id) = 0; + // Starts a graceful shutdown. A no-op for clients. + virtual void SubmitShutdownNotice() = 0; + // Submits a GOAWAY on the connection. Note that |last_accepted_stream_id| - // refers to stream IDs initiated by the peer. For client-side, this last - // stream ID must be even (or 0); for server-side, this last stream ID must be - // odd (or 0). To submit a GOAWAY with |last_accepted_stream_id| with the - // maximum stream ID, signaling imminent connection termination, call - // SubmitShutdownNotice() instead (though this is only possible server-side). + // refers to stream IDs initiated by the peer. For a server sending this + // frame, this last stream ID must be odd (or 0). virtual void SubmitGoAway(Http2StreamId last_accepted_stream_id, Http2ErrorCode error_code, absl::string_view opaque_data) = 0; @@ -56,33 +65,90 @@ class Http2Adapter { virtual void SubmitWindowUpdate(Http2StreamId stream_id, int window_increment) = 0; - // Submits a METADATA frame for the given stream (a |stream_id| of 0 indicates - // connection-level METADATA). If |fin|, the frame will also have the - // END_METADATA flag set. - virtual void SubmitMetadata(Http2StreamId stream_id, bool fin) = 0; + // Submits a RST_STREAM for the given |stream_id| and |error_code|. + virtual void SubmitRst(Http2StreamId stream_id, + Http2ErrorCode error_code) = 0; + + // Submits a sequence of METADATA frames for the given stream. A |stream_id| + // of 0 indicates connection-level METADATA. + virtual void SubmitMetadata(Http2StreamId stream_id, size_t max_frame_size, + std::unique_ptr source) = 0; + + // Invokes the visitor's OnReadyToSend() method for serialized frame data. + // Returns 0 on success. + virtual int Send() = 0; - // Returns serialized bytes for writing to the wire. - // Writes should be submitted to Http2Adapter first, so that Http2Adapter - // has data to serialize and return in this method. - virtual std::string GetBytesToWrite(absl::optional max_bytes) = 0; + // Returns the connection-level flow control window advertised by the peer. + virtual int GetSendWindowSize() const = 0; - // Returns the connection-level flow control window for the peer. - virtual int GetPeerConnectionWindow() const = 0; + // Returns the stream-level flow control window advertised by the peer. + virtual int GetStreamSendWindowSize(Http2StreamId stream_id) const = 0; + + // Returns the current upper bound on the flow control receive window for this + // stream. This value does not account for data received from the peer. + virtual int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const = 0; + + // Returns the amount of data a peer could send on a given stream. This is + // the outstanding stream receive window. + virtual int GetStreamReceiveWindowSize(Http2StreamId stream_id) const = 0; + + // Returns the total amount of data a peer could send on the connection. This + // is the outstanding connection receive window. + virtual int GetReceiveWindowSize() const = 0; + + // Returns the size of the HPACK encoder's dynamic table, including the + // per-entry overhead from the specification. + virtual int GetHpackEncoderDynamicTableSize() const = 0; + + // Returns the size of the HPACK decoder's dynamic table, including the + // per-entry overhead from the specification. + virtual int GetHpackDecoderDynamicTableSize() const = 0; + + // Gets the highest stream ID value seen in a frame received by this endpoint. + // This method is only guaranteed to work for server endpoints. + virtual Http2StreamId GetHighestReceivedStreamId() const = 0; // Marks the given amount of data as consumed for the given stream, which - // enables the nghttp2 layer to trigger WINDOW_UPDATEs as appropriate. + // enables the implementation layer to send WINDOW_UPDATEs as appropriate. virtual void MarkDataConsumedForStream(Http2StreamId stream_id, size_t num_bytes) = 0; - // Submits a RST_STREAM for the given stream. - virtual void SubmitRst(Http2StreamId stream_id, - Http2ErrorCode error_code) = 0; + // Returns the assigned stream ID if the operation succeeds. Otherwise, + // returns a negative integer indicating an error code. |data_source| may be + // nullptr if the request does not have a body. + virtual int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data) = 0; + + // Returns 0 on success. |data_source| may be nullptr if the response does not + // have a body. + virtual int SubmitResponse(Http2StreamId stream_id, + absl::Span headers, + std::unique_ptr data_source) = 0; + + // Queues trailers to be sent after any outstanding data on the stream with ID + // |stream_id|. Returns 0 on success. + virtual int SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) = 0; + + // Sets a user data pointer for the given stream. Can be called after + // SubmitRequest/SubmitResponse, or after receiving any frame for a given + // stream. + virtual void SetStreamUserData(Http2StreamId stream_id, void* user_data) = 0; + + // Returns nullptr if the stream does not exist, or if stream user data has + // not been set. + virtual void* GetStreamUserData(Http2StreamId stream_id) = 0; + + // Resumes a stream that was previously blocked (for example, due to + // DataFrameSource::SelectPayloadLength() returning kBlocked). Returns true if + // the stream was successfully resumed. + virtual bool ResumeStream(Http2StreamId stream_id) = 0; protected: // Subclasses should expose a public factory method for constructing and // initializing (via Initialize()) adapter instances. explicit Http2Adapter(Http2VisitorInterface& visitor) : visitor_(visitor) {} - virtual ~Http2Adapter() {} // Accessors. Do not transfer ownership. Http2VisitorInterface& visitor() { return visitor_; } diff --git a/gquiche/http2/adapter/http2_protocol.cc b/gquiche/http2/adapter/http2_protocol.cc index f60d56e3..be68887b 100644 --- a/gquiche/http2/adapter/http2_protocol.cc +++ b/gquiche/http2/adapter/http2_protocol.cc @@ -12,6 +12,23 @@ const char kHttp2AuthorityPseudoHeader[] = ":authority"; const char kHttp2PathPseudoHeader[] = ":path"; const char kHttp2StatusPseudoHeader[] = ":status"; +const uint8_t kMetadataFrameType = 0x4d; +const uint8_t kMetadataEndFlag = 0x04; +const uint16_t kMetadataExtensionId = 0x4d44; + +std::pair GetStringView(const HeaderRep& rep) { + if (absl::holds_alternative(rep)) { + return std::make_pair(absl::get(rep), true); + } else { + absl::string_view view = absl::get(rep); + return std::make_pair(view, false); + } +} + +bool operator==(const Http2Setting& a, const Http2Setting& b) { + return a.id == b.id && a.value == b.value; +} + absl::string_view Http2SettingsIdToString(uint16_t id) { switch (id) { case Http2KnownSettingsId::HEADER_TABLE_SIZE: @@ -32,8 +49,8 @@ absl::string_view Http2SettingsIdToString(uint16_t id) { absl::string_view Http2ErrorCodeToString(Http2ErrorCode error_code) { switch (error_code) { - case Http2ErrorCode::NO_ERROR: - return "NO_ERROR"; + case Http2ErrorCode::HTTP2_NO_ERROR: + return "HTTP2_NO_ERROR"; case Http2ErrorCode::PROTOCOL_ERROR: return "PROTOCOL_ERROR"; case Http2ErrorCode::INTERNAL_ERROR: diff --git a/gquiche/http2/adapter/http2_protocol.h b/gquiche/http2/adapter/http2_protocol.h index 49e5f161..83b3417d 100644 --- a/gquiche/http2/adapter/http2_protocol.h +++ b/gquiche/http2/adapter/http2_protocol.h @@ -5,9 +5,10 @@ #include #include -#include "base/integral_types.h" #include "absl/base/attributes.h" #include "absl/strings/string_view.h" +#include "absl/types/variant.h" +#include "gquiche/common/platform/api/quiche_export.h" namespace http2 { namespace adapter { @@ -21,16 +22,26 @@ using Http2SettingsId = uint16_t; // Represents the payload of an HTTP/2 PING frame. using Http2PingId = uint64_t; +// Represents a single header name or value. +using HeaderRep = absl::variant; + +// Boolean return value is true if |rep| holds a string_view, which is assumed +// to have an indefinite lifetime. +std::pair GetStringView(const HeaderRep& rep); + // Represents an HTTP/2 header field. A header field is a key-value pair with // lowercase keys (as specified in RFC 7540 Section 8.1.2). -using Header = std::pair; +using Header = std::pair; // Represents an HTTP/2 SETTINGS key-value parameter. -struct Http2Setting { +struct QUICHE_EXPORT_PRIVATE Http2Setting { Http2SettingsId id; uint32_t value; }; +QUICHE_EXPORT_PRIVATE bool operator==(const Http2Setting& a, + const Http2Setting& b); + // The maximum possible stream ID. const Http2StreamId kMaxStreamId = 0x7FFFFFFF; @@ -42,21 +53,44 @@ const Http2StreamId kConnectionStreamId = 0; // 7540 Section 6.5.2 (SETTINGS_MAX_FRAME_SIZE). const int kDefaultFramePayloadSizeLimit = 16 * 1024; -// The default value for the initial stream flow control window size, according -// to RFC 7540 Section 6.9.2. -const int kDefaultInitialStreamWindowSize = 64 * 1024 - 1; +// The default value for the initial stream and connection flow control window +// size, according to RFC 7540 Section 6.9.2. +const int kInitialFlowControlWindowSize = 64 * 1024 - 1; // The pseudo-header fields as specified in RFC 7540 Section 8.1.2.3 (request) // and Section 8.1.2.4 (response). -ABSL_CONST_INIT extern const char kHttp2MethodPseudoHeader[]; -ABSL_CONST_INIT extern const char kHttp2SchemePseudoHeader[]; -ABSL_CONST_INIT extern const char kHttp2AuthorityPseudoHeader[]; -ABSL_CONST_INIT extern const char kHttp2PathPseudoHeader[]; -ABSL_CONST_INIT extern const char kHttp2StatusPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const char + kHttp2MethodPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const char + kHttp2SchemePseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const char + kHttp2AuthorityPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const char + kHttp2PathPseudoHeader[]; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const char + kHttp2StatusPseudoHeader[]; + +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const uint8_t kMetadataFrameType; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const uint8_t kMetadataEndFlag; +ABSL_CONST_INIT QUICHE_EXPORT_PRIVATE extern const uint16_t + kMetadataExtensionId; + +enum class FrameType : uint8_t { + DATA = 0x0, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; // HTTP/2 error codes as specified in RFC 7540 Section 7. enum class Http2ErrorCode { - NO_ERROR = 0x0, + HTTP2_NO_ERROR = 0x0, PROTOCOL_ERROR = 0x1, INTERNAL_ERROR = 0x2, FLOW_CONTROL_ERROR = 0x3, @@ -86,7 +120,8 @@ enum Http2KnownSettingsId : Http2SettingsId { INITIAL_WINDOW_SIZE = 0x4, MAX_FRAME_SIZE = 0x5, MAX_HEADER_LIST_SIZE = 0x6, - MAX_SETTING = MAX_HEADER_LIST_SIZE + ENABLE_CONNECT_PROTOCOL = 0x8, // See RFC 8441 + MAX_SETTING = ENABLE_CONNECT_PROTOCOL }; // Returns a human-readable string representation of the given SETTINGS |id| for @@ -99,6 +134,11 @@ absl::string_view Http2SettingsIdToString(uint16_t id); // Section 7 definitions. absl::string_view Http2ErrorCodeToString(Http2ErrorCode error_code); +enum class Perspective { + kClient, + kServer, +}; + } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/http2_session.h b/gquiche/http2/adapter/http2_session.h index f2eed525..99d62520 100644 --- a/gquiche/http2/adapter/http2_session.h +++ b/gquiche/http2/adapter/http2_session.h @@ -5,19 +5,20 @@ #include "absl/strings/string_view.h" #include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/common/platform/api/quiche_export.h" namespace http2 { namespace adapter { -struct Http2SessionCallbacks {}; +struct QUICHE_EXPORT_PRIVATE Http2SessionCallbacks {}; // A class to represent the state of a single HTTP/2 connection. -class Http2Session { +class QUICHE_EXPORT_PRIVATE Http2Session { public: Http2Session() = default; virtual ~Http2Session() {} - virtual ssize_t ProcessBytes(absl::string_view bytes) = 0; + virtual int64_t ProcessBytes(absl::string_view bytes) = 0; virtual int Consume(Http2StreamId stream_id, size_t num_bytes) = 0; @@ -26,11 +27,6 @@ class Http2Session { virtual int GetRemoteWindowSize() const = 0; }; -enum class Perspective { - kClient, - kServer, -}; - } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/http2_util.cc b/gquiche/http2/adapter/http2_util.cc index ce7024ba..ae694e7c 100644 --- a/gquiche/http2/adapter/http2_util.cc +++ b/gquiche/http2/adapter/http2_util.cc @@ -1,11 +1,19 @@ #include "gquiche/http2/adapter/http2_util.h" +#include "gquiche/spdy/core/spdy_protocol.h" + namespace http2 { namespace adapter { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; +using InvalidFrameError = Http2VisitorInterface::InvalidFrameError; + +} // anonymous namespace spdy::SpdyErrorCode TranslateErrorCode(Http2ErrorCode code) { switch (code) { - case Http2ErrorCode::NO_ERROR: + case Http2ErrorCode::HTTP2_NO_ERROR: return spdy::ERROR_CODE_NO_ERROR; case Http2ErrorCode::PROTOCOL_ERROR: return spdy::ERROR_CODE_PROTOCOL_ERROR; @@ -34,12 +42,13 @@ spdy::SpdyErrorCode TranslateErrorCode(Http2ErrorCode code) { case Http2ErrorCode::HTTP_1_1_REQUIRED: return spdy::ERROR_CODE_HTTP_1_1_REQUIRED; } + return spdy::ERROR_CODE_INTERNAL_ERROR; } Http2ErrorCode TranslateErrorCode(spdy::SpdyErrorCode code) { switch (code) { case spdy::ERROR_CODE_NO_ERROR: - return Http2ErrorCode::NO_ERROR; + return Http2ErrorCode::HTTP2_NO_ERROR; case spdy::ERROR_CODE_PROTOCOL_ERROR: return Http2ErrorCode::PROTOCOL_ERROR; case spdy::ERROR_CODE_INTERNAL_ERROR: @@ -67,6 +76,48 @@ Http2ErrorCode TranslateErrorCode(spdy::SpdyErrorCode code) { case spdy::ERROR_CODE_HTTP_1_1_REQUIRED: return Http2ErrorCode::HTTP_1_1_REQUIRED; } + return Http2ErrorCode::INTERNAL_ERROR; +} + +absl::string_view ConnectionErrorToString(ConnectionError error) { + switch (error) { + case ConnectionError::kInvalidConnectionPreface: + return "InvalidConnectionPreface"; + case ConnectionError::kSendError: + return "SendError"; + case ConnectionError::kParseError: + return "ParseError"; + case ConnectionError::kHeaderError: + return "HeaderError"; + case ConnectionError::kInvalidNewStreamId: + return "InvalidNewStreamId"; + case ConnectionError::kWrongFrameSequence: + return "kWrongFrameSequence"; + case ConnectionError::kInvalidPushPromise: + return "InvalidPushPromise"; + case ConnectionError::kExceededMaxConcurrentStreams: + return "ExceededMaxConcurrentStreams"; + } + return "UnknownConnectionError"; +} + +absl::string_view InvalidFrameErrorToString( + Http2VisitorInterface::InvalidFrameError error) { + switch (error) { + case InvalidFrameError::kProtocol: + return "Protocol"; + case InvalidFrameError::kRefusedStream: + return "RefusedStream"; + case InvalidFrameError::kHttpHeader: + return "HttpHeader"; + case InvalidFrameError::kHttpMessaging: + return "HttpMessaging"; + case InvalidFrameError::kFlowControl: + return "FlowControl"; + case InvalidFrameError::kStreamClosed: + return "StreamClosed"; + } + return "UnknownInvalidFrameError"; } } // namespace adapter diff --git a/gquiche/http2/adapter/http2_util.h b/gquiche/http2/adapter/http2_util.h index 90e9744c..29584d4d 100644 --- a/gquiche/http2/adapter/http2_util.h +++ b/gquiche/http2/adapter/http2_util.h @@ -2,13 +2,23 @@ #define QUICHE_HTTP2_ADAPTER_HTTP2_UTIL_H_ #include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/spdy/core/spdy_protocol.h" namespace http2 { namespace adapter { -spdy::SpdyErrorCode TranslateErrorCode(Http2ErrorCode code); -Http2ErrorCode TranslateErrorCode(spdy::SpdyErrorCode code); +QUICHE_EXPORT_PRIVATE spdy::SpdyErrorCode TranslateErrorCode( + Http2ErrorCode code); +QUICHE_EXPORT_PRIVATE Http2ErrorCode +TranslateErrorCode(spdy::SpdyErrorCode code); + +QUICHE_EXPORT_PRIVATE absl::string_view ConnectionErrorToString( + Http2VisitorInterface::ConnectionError error); + +QUICHE_EXPORT_PRIVATE absl::string_view InvalidFrameErrorToString( + Http2VisitorInterface::InvalidFrameError error); } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/http2_visitor_interface.h b/gquiche/http2/adapter/http2_visitor_interface.h index 9605e58f..14f14f0d 100644 --- a/gquiche/http2/adapter/http2_visitor_interface.h +++ b/gquiche/http2/adapter/http2_visitor_interface.h @@ -1,10 +1,12 @@ #ifndef QUICHE_HTTP2_ADAPTER_HTTP2_VISITOR_INTERFACE_H_ #define QUICHE_HTTP2_ADAPTER_HTTP2_VISITOR_INTERFACE_H_ +#include #include #include "absl/strings/string_view.h" #include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/common/platform/api/quiche_export.h" namespace http2 { namespace adapter { @@ -34,7 +36,7 @@ namespace adapter { // - OnHeaderForStream() // - OnEndHeadersForStream() // - OnRstStream() -// - OnAbortStream() +// - OnCloseStream() // // Request closed mid-stream, e.g., with error code NO_ERROR: // - OnBeginHeadersForStream() @@ -43,16 +45,47 @@ namespace adapter { // - OnRstStream() // - OnCloseStream() // -// More details are at RFC 7540 (go/http2spec), and more examples are at -// http://google3/net/http2/server/lib/internal/h2/nghttp2/nghttp2_server_adapter_test.cc. -class Http2VisitorInterface { +// More details are at RFC 7540 (go/http2spec). +class QUICHE_EXPORT_PRIVATE Http2VisitorInterface { public: Http2VisitorInterface(const Http2VisitorInterface&) = delete; Http2VisitorInterface& operator=(const Http2VisitorInterface&) = delete; virtual ~Http2VisitorInterface() = default; - // Called when a connection-level processing error has been encountered. - virtual void OnConnectionError() = 0; + enum : int64_t { + kSendBlocked = 0, + kSendError = -1, + }; + // Called when there are serialized frames to send. Should return how many + // bytes were actually sent. May return kSendBlocked or kSendError. + virtual int64_t OnReadyToSend(absl::string_view serialized) = 0; + + // Called when a connection-level error has occurred. + enum class ConnectionError { + // The peer sent an invalid connection preface. + kInvalidConnectionPreface, + // The visitor encountered an error sending bytes to the peer. + kSendError, + // There was an error reading and framing bytes from the peer. + kParseError, + // The visitor considered a received header to be a connection error. + kHeaderError, + // The peer attempted to open a stream with an invalid stream ID. + kInvalidNewStreamId, + // The peer sent a frame that is invalid on an idle stream (before HEADERS). + kWrongFrameSequence, + // The peer sent an invalid PUSH_PROMISE frame. + kInvalidPushPromise, + // The peer exceeded the max concurrent streams limit. + kExceededMaxConcurrentStreams, + }; + virtual void OnConnectionError(ConnectionError error) = 0; + + // Called when the header for a frame is received. + virtual bool OnFrameHeader(Http2StreamId /*stream_id*/, size_t /*length*/, + uint8_t /*type*/, uint8_t /*flags*/) { + return true; + } // Called when a non-ack SETTINGS frame is received. virtual void OnSettingsStart() = 0; @@ -67,29 +100,48 @@ class Http2VisitorInterface { virtual void OnSettingsAck() = 0; // Called when the connection receives the header block for a HEADERS frame on - // a stream but has not yet parsed individual headers. - virtual void OnBeginHeadersForStream(Http2StreamId stream_id) = 0; + // a stream but has not yet parsed individual headers. Returns false if a + // fatal error has occurred. + virtual bool OnBeginHeadersForStream(Http2StreamId stream_id) = 0; // Called when the connection receives the header |key| and |value| for a // stream. The HTTP/2 pseudo-headers defined in RFC 7540 Sections 8.1.2.3 and // 8.1.2.4 are also conveyed in this callback. This method is called after - // OnBeginHeadersForStream(). - virtual void OnHeaderForStream(Http2StreamId stream_id, absl::string_view key, - absl::string_view value) = 0; + // OnBeginHeadersForStream(). May return HEADER_RST_STREAM to indicate the + // header block should be rejected. This will cause the library to queue a + // RST_STREAM frame, which will have a default error code of INTERNAL_ERROR. + // The visitor implementation may choose to queue a RST_STREAM with a + // different error code instead, which should be done before returning + // HEADER_RST_STREAM. Returning HEADER_CONNECTION_ERROR will lead to a + // non-recoverable error on the connection. + enum OnHeaderResult { + // The header was accepted. + HEADER_OK, + // The application considers the header a connection error. + HEADER_CONNECTION_ERROR, + // The application rejects the header and requests the stream be reset. + HEADER_RST_STREAM, + // The header is a violation of HTTP messaging semantics and will be reset + // with error code PROTOCOL_ERROR. + HEADER_HTTP_MESSAGING, + }; + virtual OnHeaderResult OnHeaderForStream(Http2StreamId stream_id, + absl::string_view key, + absl::string_view value) = 0; // Called when the connection has received the complete header block for a // logical HEADERS frame on a stream (which may contain CONTINUATION frames, // transparent to the user). - virtual void OnEndHeadersForStream(Http2StreamId stream_id) = 0; + virtual bool OnEndHeadersForStream(Http2StreamId stream_id) = 0; // Called when the connection receives the beginning of a DATA frame. The data // payload will be provided via subsequent calls to OnDataForStream(). - virtual void OnBeginDataForStream(Http2StreamId stream_id, + virtual bool OnBeginDataForStream(Http2StreamId stream_id, size_t payload_length) = 0; // Called when the connection receives some |data| (as part of a DATA frame // payload) for a stream. - virtual void OnDataForStream(Http2StreamId stream_id, + virtual bool OnDataForStream(Http2StreamId stream_id, absl::string_view data) = 0; // Called when the peer sends the END_STREAM flag on a stream, indicating that @@ -97,18 +149,12 @@ class Http2VisitorInterface { virtual void OnEndStream(Http2StreamId stream_id) = 0; // Called when the connection receives a RST_STREAM for a stream. This call - // will be followed by either OnCloseStream() or OnAbortStream(). + // will be followed by either OnCloseStream(). virtual void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) = 0; - // Called when a stream is closed with error code NO_ERROR. Compare with - // OnAbortStream(). - virtual void OnCloseStream(Http2StreamId stream_id) = 0; - - // Called when a stream is aborted, i.e., closed for the reason indicated by - // the given |error_code|, where error_code != NO_ERROR. Compare with - // OnCloseStream(). - virtual void OnAbortStream(Http2StreamId stream_id, + // Called when a stream is closed. + virtual void OnCloseStream(Http2StreamId stream_id, Http2ErrorCode error_code) = 0; // Called when the connection receives a PRIORITY frame. @@ -125,7 +171,7 @@ class Http2VisitorInterface { Http2StreamId promised_stream_id) = 0; // Called when the connection receives a GOAWAY frame. - virtual void OnGoAway(Http2StreamId last_accepted_stream_id, + virtual bool OnGoAway(Http2StreamId last_accepted_stream_id, Http2ErrorCode error_code, absl::string_view opaque_data) = 0; @@ -134,39 +180,57 @@ class Http2VisitorInterface { virtual void OnWindowUpdate(Http2StreamId stream_id, int window_increment) = 0; - // Called when the connection is ready to send data for a stream. The - // implementation should write at most |length| bytes of the data payload to - // the |destination_buffer| and set |end_stream| to true IFF there will be no - // more data sent on this stream. Sets |written| to the number of bytes - // written to the |destination_buffer| or a negative value if an error occurs. - virtual void OnReadyToSendDataForStream(Http2StreamId stream_id, - char* destination_buffer, - size_t length, - ssize_t* written, - bool* end_stream) = 0; - - // Called when the connection is ready to write metadata for |stream_id| to - // the wire. The implementation should write at most |length| bytes of the - // serialized metadata payload to the |buffer| and set |written| to the number - // of bytes written or a negative value if there was an error. - virtual void OnReadyToSendMetadataForStream(Http2StreamId stream_id, - char* buffer, size_t length, - ssize_t* written) = 0; + // Called immediately before a frame of the given type is sent. Should return + // 0 on success. + virtual int OnBeforeFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags) = 0; + + // Called immediately after a frame of the given type is sent. Should return 0 + // on success. |error_code| is only populated for RST_STREAM and GOAWAY frame + // types. + virtual int OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags, + uint32_t error_code) = 0; + + // Called when the connection receives an invalid frame. A return value of + // false will result in the connection entering an error state, with no + // further frame processing possible. + enum class InvalidFrameError { + // The frame contains a general protocol error. + kProtocol, + // The frame would have caused a new (invalid) stream to be opened. + kRefusedStream, + // The frame contains an invalid header field. + kHttpHeader, + // The frame contains a violation in HTTP messaging rules. + kHttpMessaging, + // The frame causes a flow control error. + kFlowControl, + // The frame is on an already closed stream or has an invalid stream ID. + kStreamClosed, + }; + virtual bool OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) = 0; // Called when the connection receives the beginning of a METADATA frame // (which may itself be the middle of a logical metadata block). The metadata // payload will be provided via subsequent calls to OnMetadataForStream(). + // TODO(birenroy): Consider removing this unnecessary method. virtual void OnBeginMetadataForStream(Http2StreamId stream_id, size_t payload_length) = 0; // Called when the connection receives |metadata| as part of a METADATA frame - // payload for a stream. - virtual void OnMetadataForStream(Http2StreamId stream_id, + // payload for a stream. Returns false if a fatal error has occurred. + virtual bool OnMetadataForStream(Http2StreamId stream_id, absl::string_view metadata) = 0; // Called when the connection has finished receiving a logical metadata block // for a stream. Note that there may be multiple metadata blocks for a stream. - virtual void OnMetadataEndForStream(Http2StreamId stream_id) = 0; + // Returns false if there was an error unpacking the metadata payload. + virtual bool OnMetadataEndForStream(Http2StreamId stream_id) = 0; + + // Invoked with an error message from the application. + virtual void OnErrorDebug(absl::string_view message) = 0; protected: Http2VisitorInterface() = default; diff --git a/gquiche/http2/adapter/mock_http2_visitor.h b/gquiche/http2/adapter/mock_http2_visitor.h index c132d7a6..ed2c7c2e 100644 --- a/gquiche/http2/adapter/mock_http2_visitor.h +++ b/gquiche/http2/adapter/mock_http2_visitor.h @@ -1,7 +1,10 @@ #ifndef QUICHE_HTTP2_ADAPTER_MOCK_HTTP2_VISITOR_INTERFACE_H_ #define QUICHE_HTTP2_ADAPTER_MOCK_HTTP2_VISITOR_INTERFACE_H_ +#include + #include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/common/platform/api/quiche_test.h" namespace http2 { @@ -9,41 +12,49 @@ namespace adapter { namespace test { // A mock visitor class, for use in tests. -class MockHttp2Visitor : public Http2VisitorInterface { +class QUICHE_NO_EXPORT MockHttp2Visitor : public Http2VisitorInterface { public: - MockHttp2Visitor() = default; - - MOCK_METHOD(void, OnConnectionError, (), (override)); + MockHttp2Visitor() { + ON_CALL(*this, OnFrameHeader).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnBeginHeadersForStream) + .WillByDefault(testing::Return(true)); + ON_CALL(*this, OnHeaderForStream).WillByDefault(testing::Return(HEADER_OK)); + ON_CALL(*this, OnEndHeadersForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnBeginDataForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnDataForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnGoAway).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnInvalidFrame).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnMetadataForStream).WillByDefault(testing::Return(true)); + ON_CALL(*this, OnMetadataEndForStream).WillByDefault(testing::Return(true)); + } + + MOCK_METHOD(int64_t, OnReadyToSend, (absl::string_view serialized), + (override)); + MOCK_METHOD(void, OnConnectionError, (ConnectionError error), (override)); + MOCK_METHOD(bool, OnFrameHeader, + (Http2StreamId stream_id, size_t length, uint8_t type, + uint8_t flags), + (override)); MOCK_METHOD(void, OnSettingsStart, (), (override)); MOCK_METHOD(void, OnSetting, (Http2Setting setting), (override)); MOCK_METHOD(void, OnSettingsEnd, (), (override)); MOCK_METHOD(void, OnSettingsAck, (), (override)); - MOCK_METHOD(void, - OnBeginHeadersForStream, - (Http2StreamId stream_id), + MOCK_METHOD(bool, OnBeginHeadersForStream, (Http2StreamId stream_id), (override)); - MOCK_METHOD(void, - OnHeaderForStream, - (Http2StreamId stream_id, - absl::string_view key, + MOCK_METHOD(OnHeaderResult, OnHeaderForStream, + (Http2StreamId stream_id, absl::string_view key, absl::string_view value), (override)); - MOCK_METHOD(void, - OnEndHeadersForStream, - (Http2StreamId stream_id), + MOCK_METHOD(bool, OnEndHeadersForStream, (Http2StreamId stream_id), (override)); - MOCK_METHOD(void, - OnBeginDataForStream, - (Http2StreamId stream_id, size_t payload_length), - (override)); + MOCK_METHOD(bool, OnBeginDataForStream, + (Http2StreamId stream_id, size_t payload_length), (override)); - MOCK_METHOD(void, - OnDataForStream, - (Http2StreamId stream_id, absl::string_view data), - (override)); + MOCK_METHOD(bool, OnDataForStream, + (Http2StreamId stream_id, absl::string_view data), (override)); MOCK_METHOD(void, OnEndStream, (Http2StreamId stream_id), (override)); @@ -52,10 +63,8 @@ class MockHttp2Visitor : public Http2VisitorInterface { (Http2StreamId stream_id, Http2ErrorCode error_code), (override)); - MOCK_METHOD(void, OnCloseStream, (Http2StreamId stream_id), (override)); - MOCK_METHOD(void, - OnAbortStream, + OnCloseStream, (Http2StreamId stream_id, Http2ErrorCode error_code), (override)); @@ -74,10 +83,8 @@ class MockHttp2Visitor : public Http2VisitorInterface { (Http2StreamId stream_id, Http2StreamId promised_stream_id), (override)); - MOCK_METHOD(void, - OnGoAway, - (Http2StreamId last_accepted_stream_id, - Http2ErrorCode error_code, + MOCK_METHOD(bool, OnGoAway, + (Http2StreamId last_accepted_stream_id, Http2ErrorCode error_code, absl::string_view opaque_data), (override)); @@ -86,35 +93,32 @@ class MockHttp2Visitor : public Http2VisitorInterface { (Http2StreamId stream_id, int window_increment), (override)); - MOCK_METHOD(void, - OnReadyToSendDataForStream, - (Http2StreamId stream_id, - char* destination_buffer, - size_t length, - ssize_t* written, - bool* end_stream), + MOCK_METHOD(int, OnBeforeFrameSent, + (uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags), (override)); - MOCK_METHOD( - void, - OnReadyToSendMetadataForStream, - (Http2StreamId stream_id, char* buffer, size_t length, ssize_t* written), - (override)); + MOCK_METHOD(int, OnFrameSent, + (uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code), + (override)); + + MOCK_METHOD(bool, OnInvalidFrame, + (Http2StreamId stream_id, InvalidFrameError error), (override)); MOCK_METHOD(void, OnBeginMetadataForStream, (Http2StreamId stream_id, size_t payload_length), (override)); - MOCK_METHOD(void, - OnMetadataForStream, + MOCK_METHOD(bool, OnMetadataForStream, (Http2StreamId stream_id, absl::string_view metadata), (override)); - MOCK_METHOD(void, - OnMetadataEndForStream, - (Http2StreamId stream_id), + MOCK_METHOD(bool, OnMetadataEndForStream, (Http2StreamId stream_id), (override)); + + MOCK_METHOD(void, OnErrorDebug, (absl::string_view message), (override)); }; } // namespace test diff --git a/gquiche/http2/adapter/mock_nghttp2_callbacks.cc b/gquiche/http2/adapter/mock_nghttp2_callbacks.cc new file mode 100644 index 00000000..e04ff88c --- /dev/null +++ b/gquiche/http2/adapter/mock_nghttp2_callbacks.cc @@ -0,0 +1,130 @@ +#include "gquiche/http2/adapter/mock_nghttp2_callbacks.h" + +#include "gquiche/http2/adapter/nghttp2_util.h" + +namespace http2 { +namespace adapter { +namespace test { + +/* static */ +nghttp2_session_callbacks_unique_ptr MockNghttp2Callbacks::GetCallbacks() { + nghttp2_session_callbacks* callbacks; + nghttp2_session_callbacks_new(&callbacks); + + // All of the callback implementations below just delegate to the mock methods + // of |user_data|, which is assumed to be a MockNghttp2Callbacks*. + nghttp2_session_callbacks_set_send_callback( + callbacks, + [](nghttp2_session*, const uint8_t* data, size_t length, int flags, + void* user_data) -> ssize_t { + return static_cast(user_data)->Send(data, length, + flags); + }); + + nghttp2_session_callbacks_set_send_data_callback( + callbacks, + [](nghttp2_session*, nghttp2_frame* frame, const uint8_t* framehd, + size_t length, nghttp2_data_source* source, void* user_data) -> int { + return static_cast(user_data)->SendData( + frame, framehd, length, source); + }); + + nghttp2_session_callbacks_set_on_begin_headers_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->OnBeginHeaders( + frame); + }); + + nghttp2_session_callbacks_set_on_header_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, const uint8_t* raw_name, + size_t name_length, const uint8_t* raw_value, size_t value_length, + uint8_t flags, void* user_data) -> int { + absl::string_view name = ToStringView(raw_name, name_length); + absl::string_view value = ToStringView(raw_value, value_length); + return static_cast(user_data)->OnHeader( + frame, name, value, flags); + }); + + nghttp2_session_callbacks_set_on_data_chunk_recv_callback( + callbacks, + [](nghttp2_session*, uint8_t flags, int32_t stream_id, + const uint8_t* data, size_t len, void* user_data) -> int { + absl::string_view chunk = ToStringView(data, len); + return static_cast(user_data)->OnDataChunkRecv( + flags, stream_id, chunk); + }); + + nghttp2_session_callbacks_set_on_begin_frame_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame_hd* hd, void* user_data) -> int { + return static_cast(user_data)->OnBeginFrame(hd); + }); + + nghttp2_session_callbacks_set_on_frame_recv_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->OnFrameRecv( + frame); + }); + + nghttp2_session_callbacks_set_on_stream_close_callback( + callbacks, + [](nghttp2_session*, int32_t stream_id, uint32_t error_code, + void* user_data) -> int { + return static_cast(user_data)->OnStreamClose( + stream_id, error_code); + }); + + nghttp2_session_callbacks_set_on_frame_send_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->OnFrameSend( + frame); + }); + + nghttp2_session_callbacks_set_before_frame_send_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, void* user_data) -> int { + return static_cast(user_data)->BeforeFrameSend( + frame); + }); + + nghttp2_session_callbacks_set_on_frame_not_send_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, int lib_error_code, + void* user_data) -> int { + return static_cast(user_data)->OnFrameNotSend( + frame, lib_error_code); + }); + + nghttp2_session_callbacks_set_on_invalid_frame_recv_callback( + callbacks, + [](nghttp2_session*, const nghttp2_frame* frame, int error_code, + void* user_data) -> int { + return static_cast(user_data) + ->OnInvalidFrameRecv(frame, error_code); + }); + + nghttp2_session_callbacks_set_error_callback2( + callbacks, + [](nghttp2_session* /*session*/, int lib_error_code, const char* msg, + size_t len, void* user_data) -> int { + return static_cast(user_data)->OnErrorCallback2( + lib_error_code, msg, len); + }); + + nghttp2_session_callbacks_set_pack_extension_callback( + callbacks, + [](nghttp2_session*, uint8_t* buf, size_t len, const nghttp2_frame* frame, + void* user_data) -> ssize_t { + return static_cast(user_data)->OnPackExtension( + buf, len, frame); + }); + return MakeCallbacksPtr(callbacks); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/mock_nghttp2_callbacks.h b/gquiche/http2/adapter/mock_nghttp2_callbacks.h new file mode 100644 index 00000000..7665ff3c --- /dev/null +++ b/gquiche/http2/adapter/mock_nghttp2_callbacks.h @@ -0,0 +1,86 @@ +#ifndef QUICHE_HTTP2_ADAPTER_MOCK_NGHTTP2_CALLBACKS_H_ +#define QUICHE_HTTP2_ADAPTER_MOCK_NGHTTP2_CALLBACKS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// This class provides a set of mock nghttp2 callbacks for use in unit test +// expectations. +class QUICHE_NO_EXPORT MockNghttp2Callbacks { + public: + MockNghttp2Callbacks() = default; + + // The caller takes ownership of the |nghttp2_session_callbacks|. + static nghttp2_session_callbacks_unique_ptr GetCallbacks(); + + MOCK_METHOD(ssize_t, + Send, + (const uint8_t* data, size_t length, int flags), + ()); + + MOCK_METHOD(int, + SendData, + (nghttp2_frame * frame, + const uint8_t* framehd, + size_t length, + nghttp2_data_source* source), + ()); + + MOCK_METHOD(int, OnBeginHeaders, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, + OnHeader, + (const nghttp2_frame* frame, + absl::string_view name, + absl::string_view value, + uint8_t flags), + ()); + + MOCK_METHOD(int, + OnDataChunkRecv, + (uint8_t flags, int32_t stream_id, absl::string_view data), + ()); + + MOCK_METHOD(int, OnBeginFrame, (const nghttp2_frame_hd* hd), ()); + + MOCK_METHOD(int, OnFrameRecv, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, OnStreamClose, (int32_t stream_id, uint32_t error_code), ()); + + MOCK_METHOD(int, BeforeFrameSend, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, OnFrameSend, (const nghttp2_frame* frame), ()); + + MOCK_METHOD(int, + OnFrameNotSend, + (const nghttp2_frame* frame, int lib_error_code), + ()); + + MOCK_METHOD(int, + OnInvalidFrameRecv, + (const nghttp2_frame* frame, int error_code), + ()); + + MOCK_METHOD(int, + OnErrorCallback2, + (int lib_error_code, const char* msg, size_t len), + ()); + + MOCK_METHOD(ssize_t, OnPackExtension, + (uint8_t * buf, size_t len, const nghttp2_frame* frame), ()); +}; + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_MOCK_NGHTTP2_CALLBACKS_H_ diff --git a/gquiche/http2/adapter/nghttp2.h b/gquiche/http2/adapter/nghttp2.h new file mode 100644 index 00000000..b8487c0f --- /dev/null +++ b/gquiche/http2/adapter/nghttp2.h @@ -0,0 +1,11 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_H_ + +#include + +// Required to build on Windows. +using ssize_t = ptrdiff_t; + +#include "third_party/nghttp2/src/lib/includes/nghttp2/nghttp2.h" + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_H_ diff --git a/gquiche/http2/adapter/nghttp2_adapter.cc b/gquiche/http2/adapter/nghttp2_adapter.cc new file mode 100644 index 00000000..26a9ea75 --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_adapter.cc @@ -0,0 +1,301 @@ +#include "gquiche/http2/adapter/nghttp2_adapter.h" + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2_callbacks.h" +#include "gquiche/http2/adapter/nghttp2_data_provider.h" +#include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/common/quiche_endian.h" + +namespace http2 { +namespace adapter { + +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; + +// A metadata source that deletes itself upon completion. +class SelfDeletingMetadataSource : public MetadataSource { + public: + explicit SelfDeletingMetadataSource(std::unique_ptr source) + : source_(std::move(source)) {} + + size_t NumFrames(size_t max_frame_size) const override { + return source_->NumFrames(max_frame_size); + } + + std::pair Pack(uint8_t* dest, size_t dest_len) override { + const auto result = source_->Pack(dest, dest_len); + if (result.first < 0 || result.second) { + delete this; + } + return result; + } + + private: + std::unique_ptr source_; +}; + +} // anonymous namespace + +/* static */ +std::unique_ptr NgHttp2Adapter::CreateClientAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options) { + auto adapter = new NgHttp2Adapter(visitor, Perspective::kClient, options); + adapter->Initialize(); + return absl::WrapUnique(adapter); +} + +/* static */ +std::unique_ptr NgHttp2Adapter::CreateServerAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options) { + auto adapter = new NgHttp2Adapter(visitor, Perspective::kServer, options); + adapter->Initialize(); + return absl::WrapUnique(adapter); +} + +bool NgHttp2Adapter::IsServerSession() const { + int result = nghttp2_session_check_server_session(session_->raw_ptr()); + QUICHE_DCHECK_EQ(perspective_ == Perspective::kServer, result > 0); + return result > 0; +} + +int64_t NgHttp2Adapter::ProcessBytes(absl::string_view bytes) { + const int64_t processed_bytes = session_->ProcessBytes(bytes); + if (processed_bytes < 0) { + visitor_.OnConnectionError(ConnectionError::kParseError); + } + return processed_bytes; +} + +void NgHttp2Adapter::SubmitSettings(absl::Span settings) { + // Submit SETTINGS, converting each Http2Setting to an nghttp2_settings_entry. + std::vector nghttp2_settings; + absl::c_transform(settings, std::back_inserter(nghttp2_settings), + [](const Http2Setting& setting) { + return nghttp2_settings_entry{setting.id, setting.value}; + }); + nghttp2_submit_settings(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + nghttp2_settings.data(), nghttp2_settings.size()); +} + +void NgHttp2Adapter::SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, + bool exclusive) { + nghttp2_priority_spec priority_spec; + nghttp2_priority_spec_init(&priority_spec, parent_stream_id, weight, + static_cast(exclusive)); + nghttp2_submit_priority(session_->raw_ptr(), NGHTTP2_FLAG_NONE, stream_id, + &priority_spec); +} + +void NgHttp2Adapter::SubmitPing(Http2PingId ping_id) { + uint8_t opaque_data[8] = {}; + Http2PingId ping_id_to_serialize = quiche::QuicheEndian::HostToNet64(ping_id); + std::memcpy(opaque_data, &ping_id_to_serialize, sizeof(Http2PingId)); + nghttp2_submit_ping(session_->raw_ptr(), NGHTTP2_FLAG_NONE, opaque_data); +} + +void NgHttp2Adapter::SubmitShutdownNotice() { + nghttp2_submit_shutdown_notice(session_->raw_ptr()); +} + +void NgHttp2Adapter::SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + nghttp2_submit_goaway(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + last_accepted_stream_id, + static_cast(error_code), + ToUint8Ptr(opaque_data.data()), opaque_data.size()); +} + +void NgHttp2Adapter::SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) { + nghttp2_submit_window_update(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + stream_id, window_increment); +} + +void NgHttp2Adapter::SubmitMetadata(Http2StreamId stream_id, + size_t max_frame_size, + std::unique_ptr source) { + auto* wrapped_source = new SelfDeletingMetadataSource(std::move(source)); + const size_t num_frames = wrapped_source->NumFrames(max_frame_size); + size_t num_successes = 0; + for (size_t i = 1; i <= num_frames; ++i) { + const int result = nghttp2_submit_extension( + session_->raw_ptr(), kMetadataFrameType, + i == num_frames ? kMetadataEndFlag : 0, stream_id, wrapped_source); + if (result != 0) { + QUICHE_LOG(DFATAL) << "Failed to submit extension frame " << i << " of " + << num_frames; + break; + } + ++num_successes; + } + if (num_successes == 0) { + delete wrapped_source; + } +} + +int NgHttp2Adapter::Send() { + const int result = nghttp2_session_send(session_->raw_ptr()); + if (result != 0) { + QUICHE_VLOG(1) << "nghttp2_session_send returned " << result; + visitor_.OnConnectionError(ConnectionError::kSendError); + } + return result; +} + +int NgHttp2Adapter::GetSendWindowSize() const { + return session_->GetRemoteWindowSize(); +} + +int NgHttp2Adapter::GetStreamSendWindowSize(Http2StreamId stream_id) const { + return nghttp2_session_get_stream_remote_window_size(session_->raw_ptr(), + stream_id); +} + +int NgHttp2Adapter::GetStreamReceiveWindowLimit(Http2StreamId stream_id) const { + return nghttp2_session_get_stream_effective_local_window_size( + session_->raw_ptr(), stream_id); +} + +int NgHttp2Adapter::GetStreamReceiveWindowSize(Http2StreamId stream_id) const { + return nghttp2_session_get_stream_local_window_size(session_->raw_ptr(), + stream_id); +} + +int NgHttp2Adapter::GetReceiveWindowSize() const { + return nghttp2_session_get_local_window_size(session_->raw_ptr()); +} + +int NgHttp2Adapter::GetHpackEncoderDynamicTableSize() const { + return nghttp2_session_get_hd_deflate_dynamic_table_size(session_->raw_ptr()); +} + +int NgHttp2Adapter::GetHpackDecoderDynamicTableSize() const { + return nghttp2_session_get_hd_inflate_dynamic_table_size(session_->raw_ptr()); +} + +Http2StreamId NgHttp2Adapter::GetHighestReceivedStreamId() const { + return nghttp2_session_get_last_proc_stream_id(session_->raw_ptr()); +} + +void NgHttp2Adapter::MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) { + int rc = session_->Consume(stream_id, num_bytes); + if (rc != 0) { + QUICHE_LOG(ERROR) << "Error " << rc << " marking " << num_bytes + << " bytes consumed for stream " << stream_id; + } +} + +void NgHttp2Adapter::SubmitRst(Http2StreamId stream_id, + Http2ErrorCode error_code) { + int status = + nghttp2_submit_rst_stream(session_->raw_ptr(), NGHTTP2_FLAG_NONE, + stream_id, static_cast(error_code)); + if (status < 0) { + QUICHE_LOG(WARNING) << "Reset stream failed: " << stream_id + << " with status code " << status; + } +} + +int32_t NgHttp2Adapter::SubmitRequest( + absl::Span headers, + std::unique_ptr data_source, void* stream_user_data) { + auto nvs = GetNghttp2Nvs(headers); + std::unique_ptr provider = + MakeDataProvider(data_source.get()); + + int32_t stream_id = + nghttp2_submit_request(session_->raw_ptr(), nullptr, nvs.data(), + nvs.size(), provider.get(), stream_user_data); + // TODO(birenroy): clean up data source on stream close + sources_.emplace(stream_id, std::move(data_source)); + QUICHE_VLOG(1) << "Submitted request with " << nvs.size() + << " request headers and user data " << stream_user_data + << "; resulted in stream " << stream_id; + return stream_id; +} + +int NgHttp2Adapter::SubmitResponse( + Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) { + auto nvs = GetNghttp2Nvs(headers); + std::unique_ptr provider = + MakeDataProvider(data_source.get()); + + // TODO(birenroy): clean up data source on stream close + sources_.emplace(stream_id, std::move(data_source)); + + int result = nghttp2_submit_response(session_->raw_ptr(), stream_id, + nvs.data(), nvs.size(), provider.get()); + QUICHE_VLOG(1) << "Submitted response with " << nvs.size() + << " response headers; result = " << result; + return result; +} + +int NgHttp2Adapter::SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) { + auto nvs = GetNghttp2Nvs(trailers); + int result = nghttp2_submit_trailer(session_->raw_ptr(), stream_id, + nvs.data(), nvs.size()); + QUICHE_VLOG(1) << "Submitted trailers with " << nvs.size() + << " response trailers; result = " << result; + return result; +} + +void NgHttp2Adapter::SetStreamUserData(Http2StreamId stream_id, + void* stream_user_data) { + nghttp2_session_set_stream_user_data(session_->raw_ptr(), stream_id, + stream_user_data); +} + +void* NgHttp2Adapter::GetStreamUserData(Http2StreamId stream_id) { + return nghttp2_session_get_stream_user_data(session_->raw_ptr(), stream_id); +} + +bool NgHttp2Adapter::ResumeStream(Http2StreamId stream_id) { + return 0 == nghttp2_session_resume_data(session_->raw_ptr(), stream_id); +} + +NgHttp2Adapter::NgHttp2Adapter(Http2VisitorInterface& visitor, + Perspective perspective, + const nghttp2_option* options) + : Http2Adapter(visitor), + visitor_(visitor), + options_(options), + perspective_(perspective) {} + +NgHttp2Adapter::~NgHttp2Adapter() {} + +void NgHttp2Adapter::Initialize() { + nghttp2_option* owned_options = nullptr; + if (options_ == nullptr) { + nghttp2_option_new(&owned_options); + // Set some common options for compatibility. + nghttp2_option_set_no_closed_streams(owned_options, 1); + nghttp2_option_set_no_auto_window_update(owned_options, 1); + nghttp2_option_set_max_send_header_block_length(owned_options, 0x2000000); + nghttp2_option_set_max_outbound_ack(owned_options, 10000); + nghttp2_option_set_user_recv_extension_type(owned_options, + kMetadataFrameType); + options_ = owned_options; + } + + session_ = absl::make_unique(perspective_, + callbacks::Create(), options_, + static_cast(&visitor_)); + if (owned_options != nullptr) { + nghttp2_option_del(owned_options); + } + options_ = nullptr; +} + +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_adapter.h b/gquiche/http2/adapter/nghttp2_adapter.h new file mode 100644 index 00000000..6be8d45a --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_adapter.h @@ -0,0 +1,110 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_ADAPTER_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_ADAPTER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "gquiche/http2/adapter/http2_adapter.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/nghttp2_session.h" +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/common/platform/api/quiche_export.h" + +namespace http2 { +namespace adapter { + +class QUICHE_EXPORT_PRIVATE NgHttp2Adapter : public Http2Adapter { + public: + ~NgHttp2Adapter() override; + + // Creates an adapter that functions as a client. Does not take ownership of + // |options|. + static std::unique_ptr CreateClientAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options = nullptr); + + // Creates an adapter that functions as a server. Does not take ownership of + // |options|. + static std::unique_ptr CreateServerAdapter( + Http2VisitorInterface& visitor, const nghttp2_option* options = nullptr); + + bool IsServerSession() const override; + bool want_read() const override { return session_->want_read(); } + bool want_write() const override { return session_->want_write(); } + + int64_t ProcessBytes(absl::string_view bytes) override; + void SubmitSettings(absl::Span settings) override; + void SubmitPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, + bool exclusive) override; + + // Submits a PING on the connection. Note that nghttp2 automatically submits + // PING acks upon receiving non-ack PINGs from the peer, so callers only use + // this method to originate PINGs. See nghttp2_option_set_no_auto_ping_ack(). + void SubmitPing(Http2PingId ping_id) override; + + void SubmitShutdownNotice() override; + void SubmitGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + + void SubmitWindowUpdate(Http2StreamId stream_id, + int window_increment) override; + + void SubmitRst(Http2StreamId stream_id, Http2ErrorCode error_code) override; + + void SubmitMetadata(Http2StreamId stream_id, size_t max_frame_size, + std::unique_ptr source) override; + + int Send() override; + + int GetSendWindowSize() const override; + int GetStreamSendWindowSize(Http2StreamId stream_id) const override; + + int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const override; + int GetStreamReceiveWindowSize(Http2StreamId stream_id) const override; + int GetReceiveWindowSize() const override; + + int GetHpackEncoderDynamicTableSize() const override; + int GetHpackDecoderDynamicTableSize() const override; + + Http2StreamId GetHighestReceivedStreamId() const override; + + void MarkDataConsumedForStream(Http2StreamId stream_id, + size_t num_bytes) override; + + int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data) override; + + int SubmitResponse(Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) override; + + int SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) override; + + void SetStreamUserData(Http2StreamId stream_id, void* user_data) override; + void* GetStreamUserData(Http2StreamId stream_id) override; + + bool ResumeStream(Http2StreamId stream_id) override; + + private: + NgHttp2Adapter(Http2VisitorInterface& visitor, Perspective perspective, + const nghttp2_option* options); + + // Performs any necessary initialization of the underlying HTTP/2 session, + // such as preparing initial SETTINGS. + void Initialize(); + + std::unique_ptr session_; + Http2VisitorInterface& visitor_; + const nghttp2_option* options_; + Perspective perspective_; + + absl::flat_hash_map> sources_; +}; + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_ADAPTER_H_ diff --git a/gquiche/http2/adapter/nghttp2_adapter_test.cc b/gquiche/http2/adapter/nghttp2_adapter_test.cc new file mode 100644 index 00000000..a5e46ef3 --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_adapter_test.cc @@ -0,0 +1,3408 @@ +#include "gquiche/http2/adapter/nghttp2_adapter.h" + +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/mock_http2_visitor.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2_test_utils.h" +#include "gquiche/http2/adapter/oghttp2_util.h" +#include "gquiche/http2/adapter/test_frame_sequence.h" +#include "gquiche/http2/adapter/test_utils.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; + +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +// This send callback assumes |source|'s pointer is a TestDataSource, and +// |user_data| is a Http2VisitorInterface. +int TestSendCallback(nghttp2_session*, nghttp2_frame* /*frame*/, + const uint8_t* framehd, size_t length, + nghttp2_data_source* source, void* user_data) { + auto* visitor = static_cast(user_data); + // Send the frame header via the visitor. + ssize_t result = visitor->OnReadyToSend(ToStringView(framehd, 9)); + if (result == 0) { + return NGHTTP2_ERR_WOULDBLOCK; + } + auto* test_source = static_cast(source->ptr); + absl::string_view payload = test_source->ReadNext(length); + // Send the frame payload via the visitor. + visitor->OnReadyToSend(payload); + return 0; +} + +TEST(NgHttp2AdapterTest, ClientConstruction) { + testing::StrictMock visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + ASSERT_NE(nullptr, adapter); + EXPECT_TRUE(adapter->want_read()); + EXPECT_FALSE(adapter->want_write()); + EXPECT_FALSE(adapter->IsServerSession()); +} + +TEST(NgHttp2AdapterTest, ClientHandlesFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_EQ(adapter->GetSendWindowSize(), kInitialFlowControlWindowSize + 1000); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x1, 0)); + + result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING})); + visitor.Clear(); + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const std::vector headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + + const std::vector headers3 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const char* kSentinel3 = "arbitrary pointer 3"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id2; + + const int32_t stream_id3 = + adapter->SubmitRequest(headers3, nullptr, const_cast(kSentinel3)); + ASSERT_GT(stream_id3, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id3; + + const char* kSentinel2 = "arbitrary pointer 2"; + adapter->SetStreamUserData(stream_id2, const_cast(kSentinel2)); + adapter->SetStreamUserData(stream_id3, nullptr); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id3, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id3, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + // All streams are active and have not yet received any data, so the receive + // window should be at the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id1)); + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id2)); + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id3)); + + // Upper bound on the flow control receive window should be the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(stream_id1)); + + // Connection has not yet received any data. + EXPECT_EQ(kInitialFlowControlWindowSize, adapter->GetReceiveWindowSize()); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(stream_id1)); + EXPECT_EQ(kSentinel2, adapter->GetStreamUserData(stream_id2)); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(stream_id3)); + + EXPECT_EQ(0, adapter->GetHpackDecoderDynamicTableSize()); + + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .RstStream(3, Http2ErrorCode::INTERNAL_ERROR) + .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnFrameHeader(0, 19, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!")); + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + // First stream has received some data. + EXPECT_GT(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id1)); + // Second stream was closed. + EXPECT_EQ(-1, adapter->GetStreamReceiveWindowSize(stream_id2)); + // Third stream has not received any data. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id3)); + + // Connection window should be the same as the first stream. + EXPECT_EQ(adapter->GetReceiveWindowSize(), + adapter->GetStreamReceiveWindowSize(stream_id1)); + + // Upper bound on the flow control receive window should still be the initial + // value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(stream_id1)); + + EXPECT_GT(adapter->GetHpackDecoderDynamicTableSize(), 0); + + // Should be 3, but this method only works for server adapters. + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + // Even though the client recieved a GOAWAY, streams 1 and 5 are still active. + EXPECT_TRUE(adapter->want_read()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 0, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 0)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnFrameHeader(5, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(5, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor, OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM)); + adapter->ProcessBytes(TestFrameSequence() + .Data(1, "", true) + .RstStream(5, Http2ErrorCode::REFUSED_STREAM) + .Serialize()); + + // Should be 5, but this method only works for server adapters. + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + // After receiving END_STREAM for 1 and RST_STREAM for 5, the session no + // longer expects reads. + EXPECT_FALSE(adapter->want_read()); + + // Client will not have anything else to write. + EXPECT_FALSE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); +} + +TEST(NgHttp2AdapterTest, ClientRejects100HeadersWithFin) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, /*fin=*/false) + .Headers(1, {{":status", "100"}}, /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{"final-status", "A-OK"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, "final-status", "A-OK")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesMetadata) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesMetadataWithError) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)) + .WillOnce(testing::Return(false)); + // Remaining frames are not processed due to the error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + // The false return from OnMetadataForStream() results in a connection error. + EXPECT_EQ(stream_result, NGHTTP2_ERR_CALLBACK_FAILURE); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_TRUE(adapter->want_read()); // Even after an error. Why? + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientHandlesHpackHeaderTableSetting) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"x-i-do-not-like", "green eggs and ham"}, + {"x-i-will-not-eat-them", "here or there, in a box, with a fox"}, + {"x-like-them-in-a-house", "no"}, + {"x-like-them-with-a-mouse", "no"}, + }); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 100); + + const std::string stream_frames = + TestFrameSequence().Settings({{HEADER_TABLE_SIZE, 100u}}).Serialize(); + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 100u})); + + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_LE(adapter->GetHpackEncoderDynamicTableSize(), 100); +} + +TEST(NgHttp2AdapterTest, ClientHandlesInvalidTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{":bad-status", "9000"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [:bad-status], value: [9000]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + // Bad status trailer will cause a PROTOCOL_ERROR. The header is never + // delivered in an OnHeaderForStream callback. + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientRstStreamWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce(testing::DoAll( + testing::InvokeWithoutArgs([&adapter]() { + adapter->SubmitRst(1, Http2ErrorCode::REFUSED_STREAM); + }), + testing::Return(Http2VisitorInterface::HEADER_RST_STREAM))); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::REFUSED_STREAM)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + // Translation to nghttp2 treats this error as a general parsing error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(-902 /* NGHTTP2_ERR_CALLBACK_FAILURE */, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientConnectionErrorWhileHandlingHeadersOnly) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + // Translation to nghttp2 treats this error as a general parsing error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(-902 /* NGHTTP2_ERR_CALLBACK_FAILURE */, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientRejectsHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)) + .WillOnce(testing::Return(false)); + // Rejecting headers leads to a connection error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientFailsOnGoAway) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + EXPECT_CALL(visitor, + OnGoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion")) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ClientRejects101Response) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"upgrade", "new-protocol"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "101"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL( + visitor, + OnErrorDebug("Invalid HTTP header field was received: frame type: 1, " + "stream: 1, name: [:status], value: [101]")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_frames.size()), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ClientSubmitRequest) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHpackEncoderDynamicTableSize()); + EXPECT_FALSE(adapter->want_write()); + const char* kSentinel = ""; + const absl::string_view kBody = "This is an example request body."; + auto body1 = absl::make_unique(visitor, true); + body1->AppendPayload(kBody); + body1->EndData(); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(stream_id)); + EXPECT_EQ(kInitialFlowControlWindowSize, adapter->GetReceiveWindowSize()); + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(stream_id)); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 0); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(adapter->GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamSendWindowSize(stream_id), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(-1, adapter->GetStreamSendWindowSize(stream_id + 2)); + + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + const char* kSentinel2 = "arbitrary pointer 2"; + EXPECT_EQ(nullptr, adapter->GetStreamUserData(stream_id)); + adapter->SetStreamUserData(stream_id, const_cast(kSentinel2)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + + EXPECT_EQ(kSentinel2, adapter->GetStreamUserData(stream_id)); + + // No data was sent (just HEADERS), so the remaining send window size should + // still be the default. + EXPECT_EQ(adapter->GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); +} + +// This is really a test of the MakeZeroCopyDataFrameSource adapter, but I +// wasn't sure where else to put it. +TEST(NgHttp2AdapterTest, ClientSubmitRequestWithDataProvider) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example request body."; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + EXPECT_FALSE(adapter->want_write()); +} + +// This test verifies how nghttp2 behaves when a data source becomes +// read-blocked. +TEST(NgHttp2AdapterTest, ClientSubmitRequestWithDataProviderAndReadBlock) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const absl::string_view kBody = "This is an example request body."; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kBody}; + body1.set_is_data_available(false); + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // Resume the deferred stream. + body1.set_is_data_available(true); + EXPECT_TRUE(adapter->ResumeStream(stream_id)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::DATA})); + EXPECT_FALSE(adapter->want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(adapter->ResumeStream(stream_id)); + EXPECT_FALSE(adapter->want_write()); +} + +// This test verifies how nghttp2 behaves when a data source is read block, then +// ends with an empty DATA frame. +TEST(NgHttp2AdapterTest, ClientSubmitRequestEmptyDataWithFin) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const absl::string_view kEmptyBody = ""; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kEmptyBody}; + body1.set_is_data_available(false); + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // Resume the deferred stream. + body1.set_is_data_available(true); + EXPECT_TRUE(adapter->ResumeStream(stream_id)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::DATA})); + EXPECT_FALSE(adapter->want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(adapter->ResumeStream(stream_id)); + EXPECT_FALSE(adapter->want_write()); +} + +// This test verifies how nghttp2 behaves when a connection becomes +// write-blocked. +TEST(NgHttp2AdapterTest, ClientSubmitRequestWithDataProviderAndWriteBlock) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const absl::string_view kBody = "This is an example request body."; + // This test will use TestDataSource as the source of the body payload data. + TestDataSource body1{kBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + nghttp2_send_data_callback send_callback = &TestSendCallback; + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &visitor, std::move(send_callback)); + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(frame_source), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + visitor.set_is_write_blocked(true); + int result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + visitor.set_is_write_blocked(false); + result = adapter->Send(); + EXPECT_EQ(0, result); + + // Client preface does not appear to include the mandatory SETTINGS frame. + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientReceivesDataOnClosedStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + // Client SETTINGS ack + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client open a stream with a request. + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Let the client RST_STREAM the stream it opened. + adapter->SubmitRst(stream_id, Http2ErrorCode::CANCEL); + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id, _, 0x0, + static_cast(Http2ErrorCode::CANCEL))); + EXPECT_CALL(visitor, OnCloseStream(stream_id, Http2ErrorCode::CANCEL)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM})); + visitor.Clear(); + + // Let the server send a response on the stream. (It might not have received + // the RST_STREAM yet.) + const std::string response_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + // The visitor gets notified about the HEADERS frame but not the DATA frame on + // the closed stream. No further processing for either frame occurs. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, DATA, _)).Times(0); + + const int64_t response_result = adapter->ProcessBytes(response_frames); + EXPECT_EQ(response_frames.size(), response_result); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, SubmitMetadata) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + auto source = absl::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT( + serialized, + EqualsFrames({static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, SubmitMetadataMultipleFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + const auto kLargeValue = std::string(63 * 1024, 'a'); + auto source = absl::make_unique( + ToHeaderBlock(ToHeaders({{"large-value", kLargeValue}}))); + adapter->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + testing::InSequence seq; + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT( + serialized, + EqualsFrames({static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, SubmitConnectionMetadata) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + auto source = absl::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter->SubmitMetadata(0, 16384u, std::move(source)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(kMetadataFrameType, 0, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(kMetadataFrameType, 0, _, 0x4, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT( + serialized, + EqualsFrames({static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientObeysMaxConcurrentStreams) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + int result = adapter->Send(); + EXPECT_EQ(0, result); + // Client preface does not appear to include the mandatory SETTINGS frame. + EXPECT_THAT(visitor.data(), + testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{.id = MAX_CONCURRENT_STREAMS, .value = 1}}) + .Serialize(); + testing::InSequence s; + + // Server preface (SETTINGS with MAX_CONCURRENT_STREAMS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example request body."; + auto body1 = absl::make_unique(visitor, true); + body1->AppendPayload(kBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + const int next_stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + + // A new pending stream is created, but because of MAX_CONCURRENT_STREAMS, the + // session should not want to write it at the moment. + EXPECT_GT(next_stream_id, stream_id); + EXPECT_FALSE(adapter->want_write()); + + const std::string stream_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, + OnHeaderForStream(stream_id, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, "date", + "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 26, DATA, 0x1)); + EXPECT_CALL(visitor, OnBeginDataForStream(stream_id, 26)); + EXPECT_CALL(visitor, + OnDataForStream(stream_id, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(stream_id)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + // The first stream should close, which should make the session want to write + // the next stream. + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, next_stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, next_stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ClientForbidsPushPromise) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + adapter->SubmitSettings({{ENABLE_PUSH, 0}}); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::vector push_headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/push"}}); + const std::string frames = TestFrameSequence() + .ServerPreface() + .SettingsAck() + .PushPromise(stream_id, 2, push_headers) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The PUSH_PROMISE is now treated as an invalid frame. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, PUSH_PROMISE, _)); + EXPECT_CALL(visitor, OnInvalidFrame(stream_id, _)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), read_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); +} + +TEST(NgHttp2AdapterTest, ClientForbidsPushStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + adapter->SubmitSettings({{ENABLE_PUSH, 0}}); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(2, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The push HEADERS are invalid. + EXPECT_CALL(visitor, OnFrameHeader(2, _, HEADERS, _)); + EXPECT_CALL(visitor, OnInvalidFrame(2, _)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), read_result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); +} + +TEST(NgHttp2AdapterTest, FailureSendingConnectionPreface) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateClientAdapter(visitor); + + visitor.set_has_write_error(); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int result = adapter->Send(); + EXPECT_EQ(result, NGHTTP2_ERR_CALLBACK_FAILURE); +} + +TEST(NgHttp2AdapterTest, ServerConstruction) { + testing::StrictMock visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + ASSERT_NE(nullptr, adapter); + EXPECT_TRUE(adapter->want_read()); + EXPECT_FALSE(adapter->want_write()); + EXPECT_TRUE(adapter->IsServerSession()); +} + +TEST(NgHttp2AdapterTest, ServerHandlesFrames) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + EXPECT_EQ(0, adapter->GetHpackDecoderDynamicTableSize()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&adapter, kSentinel1]() { + adapter->SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "http")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/two")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnEndStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(47, false)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(1)); + + EXPECT_GT(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowSize(1)); + EXPECT_EQ(adapter->GetStreamReceiveWindowSize(1), + adapter->GetReceiveWindowSize()); + // Upper bound should still be the original value. + EXPECT_EQ(kInitialFlowControlWindowSize, + adapter->GetStreamReceiveWindowLimit(1)); + + EXPECT_GT(adapter->GetHpackDecoderDynamicTableSize(), 0); + + // Because stream 3 has already been closed, it's not possible to set user + // data. + const char* kSentinel3 = "another arbitrary pointer"; + adapter->SetStreamUserData(3, const_cast(kSentinel3)); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(3)); + + EXPECT_EQ(3, adapter->GetHighestReceivedStreamId()); + + EXPECT_EQ(adapter->GetSendWindowSize(), kInitialFlowControlWindowSize + 1000); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack, two PING acks. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING, + spdy::SpdyFrameType::PING})); +} + +// Tests the case where the response body is in the progress of being sent while +// trailers are queued. +TEST(NgHttp2AdapterTest, ServerSubmitsTrailersWhileDataDeferred) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload(kBody); + auto* body1_ptr = body1.get(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{"final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + + // Even though the data source has not finished sending data, nghttp2 will + // write the trailers anyway. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Resuming the stream results in the library wanting to write again. + body1_ptr->AppendPayload(kBody); + body1_ptr->EndData(); + adapter->ResumeStream(1); + EXPECT_TRUE(adapter->want_write()); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + // But no data is written for the stream. + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ServerErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "some bogus value!")) + .WillOnce(testing::Return(Http2VisitorInterface::HEADER_RST_STREAM)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + // DATA frame is not delivered to the visitor. + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"Accept", "uppercase, oh boy!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnErrorDebug); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)) + .WillOnce(testing::Return(false)); + // Translation to nghttp2 treats this error as a general parsing error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(result, NGHTTP2_ERR_CALLBACK_FAILURE); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack and RST_STREAM + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, ServerErrorAfterHandlingHeaders) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(-902, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +// Exercises the case when a visitor chooses to reject a frame based solely on +// the frame header, which is a fatal error for the connection. +TEST(NgHttp2AdapterTest, ServerRejectsFrameHeader) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(64) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(-902, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerRejectsBeginningOfData) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerRejectsStreamData) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, _)).WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, result); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerSubmitResponse) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&adapter, kSentinel1]() { + adapter->SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + // Server will want to send a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, adapter->GetHpackEncoderDynamicTableSize()); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + // A data fin is not sent so that the stream remains open, and the flow + // control state can be verified. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload(kBody); + int submit_result = adapter->SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + // Stream user data should have been set successfully after receiving headers. + EXPECT_EQ(kSentinel1, adapter->GetStreamUserData(1)); + adapter->SetStreamUserData(1, nullptr); + EXPECT_EQ(nullptr, adapter->GetStreamUserData(1)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + EXPECT_FALSE(adapter->want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(adapter->GetStreamSendWindowSize(1), kInitialFlowControlWindowSize); + EXPECT_GT(adapter->GetStreamSendWindowSize(1), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(adapter->GetStreamSendWindowSize(3), -1); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 0); +} + +// Should also test: client attempts shutdown, server attempts shutdown after an +// explicit GOAWAY. +TEST(NgHttp2AdapterTest, ServerSendsShutdown) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + adapter->SubmitShutdownNotice(); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerSendsTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + // Server will want to send a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(adapter->want_write()); + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = adapter->SubmitTrailer( + 1, ToHeaders({{"final-status", "a-ok"}, + {"x-comment", "trailers sure are cool"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); +} + +TEST(NgHttp2AdapterTest, ClientSendsContinuation) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true, + /*add_continuation=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 1)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); +} + +TEST(NgHttp2AdapterTest, ClientSendsMetadataWithContinuation) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Metadata(0, "Example connection metadata in multiple frames", true) + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false, + /*add_continuation=*/true) + .Metadata(1, + "Some stream metadata that's also sent in multiple frames", + true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Metadata on stream 0 + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + // Metadata on stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + EXPECT_EQ(TestFrameSequence::MetadataBlockForPayload( + "Example connection metadata in multiple frames"), + absl::StrJoin(visitor.GetMetadata(0), "")); + EXPECT_EQ(TestFrameSequence::MetadataBlockForPayload( + "Some stream metadata that's also sent in multiple frames"), + absl::StrJoin(visitor.GetMetadata(1), "")); +} + +TEST(NgHttp2AdapterTest, ServerSubmitsResponseWithDataSourceError) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + auto body1 = absl::make_unique(visitor, false); + body1->SimulateError(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 2)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::RST_STREAM})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + // The library does not object to the user queuing trailers, even through the + // stream has already been closed. + EXPECT_EQ(trailer_result, 0); +} + +TEST(NgHttp2AdapterTest, CompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the response body.", /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, IncompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + // BUG: Should send RST_STREAM NO_ERROR as well, but nghttp2 does not. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(NgHttp2AdapterTest, ServerSendsInvalidTrailers) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); +} + +TEST(NgHttp2AdapterTest, ServerDropsNewStreamBelowWatermark) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(3, "This is the request body.") + .Headers(1, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 25)); + EXPECT_CALL(visitor, OnDataForStream(3, "This is the request body.")); + + // It looks like nghttp2 delivers the under-watermark frame header but + // otherwise silently drops the rest of the frame without error. + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnInvalidFrame).Times(0); + EXPECT_CALL(visitor, OnConnectionError).Times(0); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(3, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsWindowUpdateOnIdleStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(1, _)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the SETTINGS ack to be dropped. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsDataOnIdleStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Data(1, "Sorry, out of order") + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + // In this case, nghttp2 goes straight to GOAWAY and does not invoke the + // invalid frame callback. + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the SETTINGS ack to be dropped. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsRstStreamOnIdleStream) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .RstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnInvalidFrame(1, _)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), result); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // The GOAWAY apparently causes the SETTINGS ack to be dropped. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerForbidsNewStreamAboveStreamLimit) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client send a SETTINGS ack and then attempt to open more than the + // advertised number of streams. The overflow stream should be rejected. + const std::string stream_frames = + TestFrameSequence() + .SettingsAck() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kProtocol)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), stream_result); + + // The server should send a GOAWAY for this error, even though + // OnInvalidFrame() returns true. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::GOAWAY})); +} + +TEST(NgHttp2AdapterTest, ServerRstStreamsNewStreamAboveStreamLimitBeforeAck) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), initial_result); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client avoid sending a SETTINGS ack and attempt to open more than + // the advertised number of streams. The server should still reject the + // overflow stream, albeit with RST_STREAM REFUSED_STREAM instead of GOAWAY. + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 3, Http2VisitorInterface::InvalidFrameError::kRefusedStream)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_result, stream_frames.size()); + + // The server sends a RST_STREAM for the offending stream. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(NgHttp2AdapterTest, AutomaticSettingsAndPingAcks) { + DataSavingVisitor visitor; + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor); + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface does not appear to include the mandatory SETTINGS frame. + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // PING ack + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING})); +} + +TEST(NgHttp2AdapterTest, AutomaticPingAcksDisabled) { + DataSavingVisitor visitor; + nghttp2_option* options; + nghttp2_option_new(&options); + nghttp2_option_set_no_auto_ping_ack(options, 1); + auto adapter = NgHttp2Adapter::CreateServerAdapter(visitor, options); + nghttp2_option_del(options); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface does not appear to include the mandatory SETTINGS frame. + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // No PING ack expected because automatic PING acks are disabled. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_callbacks.cc b/gquiche/http2/adapter/nghttp2_callbacks.cc index 0f339aeb..1495dbe2 100644 --- a/gquiche/http2/adapter/nghttp2_callbacks.cc +++ b/gquiche/http2/adapter/nghttp2_callbacks.cc @@ -4,11 +4,12 @@ #include #include "absl/strings/string_view.h" +#include "gquiche/http2/adapter/data_source.h" #include "gquiche/http2/adapter/http2_protocol.h" #include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/nghttp2_data_provider.h" #include "gquiche/http2/adapter/nghttp2_util.h" -#include "third_party/nghttp2/nghttp2.h" -#include "third_party/nghttp2/src/lib/includes/nghttp2/nghttp2.h" +#include "gquiche/common/platform/api/quiche_bug_tracker.h" #include "gquiche/common/platform/api/quiche_logging.h" #include "gquiche/common/quiche_endian.h" @@ -16,23 +17,46 @@ namespace http2 { namespace adapter { namespace callbacks { +ssize_t OnReadyToSend(nghttp2_session* /* session */, const uint8_t* data, + size_t length, int flags, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const int64_t result = visitor->OnReadyToSend(ToStringView(data, length)); + QUICHE_VLOG(1) << "OnReadyToSend(length=" << length << ", flags=" << flags + << ") returning " << result; + if (result > 0) { + return result; + } else if (result == Http2VisitorInterface::kSendBlocked) { + return -504; // NGHTTP2_ERR_WOULDBLOCK + } else { + return -902; // NGHTTP2_ERR_CALLBACK_FAILURE + } +} + int OnBeginFrame(nghttp2_session* /* session */, const nghttp2_frame_hd* header, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); auto* visitor = static_cast(user_data); + bool result = visitor->OnFrameHeader(header->stream_id, header->length, + header->type, header->flags); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } if (header->type == NGHTTP2_DATA) { - visitor->OnBeginDataForStream(header->stream_id, header->length); + result = visitor->OnBeginDataForStream(header->stream_id, header->length); + } else if (header->type == kMetadataFrameType) { + visitor->OnBeginMetadataForStream(header->stream_id, header->length); } - return 0; + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; } int OnFrameReceived(nghttp2_session* /* session */, const nghttp2_frame* frame, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); auto* visitor = static_cast(user_data); const Http2StreamId stream_id = frame->hd.stream_id; - QUICHE_VLOG(2) << "Frame " << static_cast(frame->hd.type) - << " for stream " << stream_id; switch (frame->hd.type) { // The beginning of the DATA frame is handled in OnBeginFrame(), and the // beginning of the header block is handled in client/server-specific @@ -45,7 +69,10 @@ int OnFrameReceived(nghttp2_session* /* session */, break; case NGHTTP2_HEADERS: { if (frame->hd.flags & NGHTTP2_FLAG_END_HEADERS) { - visitor->OnEndHeadersForStream(stream_id); + const bool result = visitor->OnEndHeadersForStream(stream_id); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } } if (frame->hd.flags & NGHTTP2_FLAG_END_STREAM) { visitor->OnEndStream(stream_id); @@ -69,12 +96,11 @@ int OnFrameReceived(nghttp2_session* /* session */, visitor->OnSettingsAck(); } else { visitor->OnSettingsStart(); - for (int i = 0; i < frame->settings.niv; ++i) { + for (size_t i = 0; i < frame->settings.niv; ++i) { nghttp2_settings_entry entry = frame->settings.iv[i]; // The nghttp2_settings_entry uses int32_t for the ID; we must cast. visitor->OnSetting(Http2Setting{ - .id = static_cast(entry.settings_id), - .value = entry.value}); + static_cast(entry.settings_id), entry.value}); } visitor->OnSettingsEnd(); } @@ -100,9 +126,12 @@ int OnFrameReceived(nghttp2_session* /* session */, absl::string_view opaque_data( reinterpret_cast(frame->goaway.opaque_data), frame->goaway.opaque_data_len); - visitor->OnGoAway(frame->goaway.last_stream_id, - ToHttp2ErrorCode(frame->goaway.error_code), - opaque_data); + const bool result = visitor->OnGoAway( + frame->goaway.last_stream_id, + ToHttp2ErrorCode(frame->goaway.error_code), opaque_data); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } break; } case NGHTTP2_WINDOW_UPDATE: { @@ -127,71 +156,148 @@ int OnFrameReceived(nghttp2_session* /* session */, int OnBeginHeaders(nghttp2_session* /* session */, const nghttp2_frame* frame, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); auto* visitor = static_cast(user_data); - visitor->OnBeginHeadersForStream(frame->hd.stream_id); - return 0; + const bool result = visitor->OnBeginHeadersForStream(frame->hd.stream_id); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; } -int OnHeader(nghttp2_session* /* session */, - const nghttp2_frame* frame, - nghttp2_rcbuf* name, - nghttp2_rcbuf* value, - uint8_t flags, +int OnHeader(nghttp2_session* /* session */, const nghttp2_frame* frame, + nghttp2_rcbuf* name, nghttp2_rcbuf* value, uint8_t /*flags*/, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); auto* visitor = static_cast(user_data); - visitor->OnHeaderForStream(frame->hd.stream_id, ToStringView(name), - ToStringView(value)); - return 0; + const Http2VisitorInterface::OnHeaderResult result = + visitor->OnHeaderForStream(frame->hd.stream_id, ToStringView(name), + ToStringView(value)); + switch (result) { + case Http2VisitorInterface::HEADER_OK: + return 0; + case Http2VisitorInterface::HEADER_CONNECTION_ERROR: + return NGHTTP2_ERR_CALLBACK_FAILURE; + case Http2VisitorInterface::HEADER_RST_STREAM: + case Http2VisitorInterface::HEADER_HTTP_MESSAGING: + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; + } + // Unexpected value. + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; +} + +int OnBeforeFrameSent(nghttp2_session* /* session */, + const nghttp2_frame* frame, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + LogBeforeSend(*frame); + auto* visitor = static_cast(user_data); + return visitor->OnBeforeFrameSent(frame->hd.type, frame->hd.stream_id, + frame->hd.length, frame->hd.flags); } -int OnDataChunk(nghttp2_session* /* session */, - uint8_t flags, - Http2StreamId stream_id, - const uint8_t* data, - size_t len, +int OnFrameSent(nghttp2_session* /* session */, const nghttp2_frame* frame, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); auto* visitor = static_cast(user_data); - visitor->OnDataForStream( + uint32_t error_code = 0; + if (frame->hd.type == NGHTTP2_RST_STREAM) { + error_code = frame->rst_stream.error_code; + } else if (frame->hd.type == NGHTTP2_GOAWAY) { + error_code = frame->goaway.error_code; + } + return visitor->OnFrameSent(frame->hd.type, frame->hd.stream_id, + frame->hd.length, frame->hd.flags, error_code); +} + +int OnInvalidFrameReceived(nghttp2_session* /* session */, + const nghttp2_frame* frame, int lib_error_code, + void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const bool result = visitor->OnInvalidFrame( + frame->hd.stream_id, ToInvalidFrameError(lib_error_code)); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnDataChunk(nghttp2_session* /* session */, uint8_t /*flags*/, + Http2StreamId stream_id, const uint8_t* data, size_t len, + void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + const bool result = visitor->OnDataForStream( stream_id, absl::string_view(reinterpret_cast(data), len)); - return 0; + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; } int OnStreamClosed(nghttp2_session* /* session */, Http2StreamId stream_id, uint32_t error_code, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); auto* visitor = static_cast(user_data); - if (error_code == static_cast(Http2ErrorCode::NO_ERROR)) { - visitor->OnCloseStream(stream_id); - } else { - visitor->OnAbortStream(stream_id, ToHttp2ErrorCode(error_code)); + visitor->OnCloseStream(stream_id, ToHttp2ErrorCode(error_code)); + return 0; +} + +int OnExtensionChunkReceived(nghttp2_session* /*session*/, + const nghttp2_frame_hd* hd, const uint8_t* data, + size_t len, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + if (hd->type != kMetadataFrameType) { + QUICHE_LOG(ERROR) << "Unexpected frame type: " + << static_cast(hd->type); + return NGHTTP2_ERR_CANCEL; + } + const bool result = + visitor->OnMetadataForStream(hd->stream_id, ToStringView(data, len)); + return result ? 0 : NGHTTP2_ERR_CALLBACK_FAILURE; +} + +int OnUnpackExtensionCallback(nghttp2_session* /*session*/, void** /*payload*/, + const nghttp2_frame_hd* hd, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + if (hd->flags == kMetadataEndFlag) { + const bool result = visitor->OnMetadataEndForStream(hd->stream_id); + if (!result) { + return NGHTTP2_ERR_CALLBACK_FAILURE; + } } return 0; } -ssize_t OnReadyToReadDataForStream(nghttp2_session* /* session */, - Http2StreamId stream_id, - uint8_t* dest_buffer, - size_t max_length, - uint32_t* data_flags, - nghttp2_data_source* source, - void* user_data) { - auto* visitor = static_cast(source->ptr); - ssize_t bytes_to_send = 0; - bool end_stream = false; - visitor->OnReadyToSendDataForStream(stream_id, - reinterpret_cast(dest_buffer), - max_length, &bytes_to_send, &end_stream); - if (bytes_to_send >= 0 && end_stream) { - *data_flags |= NGHTTP2_DATA_FLAG_EOF; +ssize_t OnPackExtensionCallback(nghttp2_session* /*session*/, uint8_t* buf, + size_t len, const nghttp2_frame* frame, + void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* source = static_cast(frame->ext.payload); + if (source == nullptr) { + QUICHE_BUG(payload_is_nullptr) << "Extension frame payload for stream " + << frame->hd.stream_id << " is null!"; + return NGHTTP2_ERR_CALLBACK_FAILURE; + } + const std::pair result = source->Pack(buf, len); + if (result.first < 0) { + return NGHTTP2_ERR_CALLBACK_FAILURE; } - return bytes_to_send; + const bool end_metadata_flag = (frame->hd.flags & kMetadataEndFlag); + QUICHE_LOG_IF(DFATAL, result.second != end_metadata_flag) + << "Metadata ends: " << result.second + << " has kMetadataEndFlag: " << end_metadata_flag; + return result.first; +} + +int OnError(nghttp2_session* /*session*/, int /*lib_error_code*/, + const char* msg, size_t len, void* user_data) { + QUICHE_CHECK_NE(user_data, nullptr); + auto* visitor = static_cast(user_data); + visitor->OnErrorDebug(absl::string_view(msg, len)); + return 0; } -nghttp2_session_callbacks* Create() { +nghttp2_session_callbacks_unique_ptr Create() { nghttp2_session_callbacks* callbacks; nghttp2_session_callbacks_new(&callbacks); + nghttp2_session_callbacks_set_send_callback(callbacks, &OnReadyToSend); nghttp2_session_callbacks_set_on_begin_frame_callback(callbacks, &OnBeginFrame); nghttp2_session_callbacks_set_on_frame_recv_callback(callbacks, @@ -203,7 +309,22 @@ nghttp2_session_callbacks* Create() { &OnDataChunk); nghttp2_session_callbacks_set_on_stream_close_callback(callbacks, &OnStreamClosed); - return callbacks; + nghttp2_session_callbacks_set_before_frame_send_callback(callbacks, + &OnBeforeFrameSent); + nghttp2_session_callbacks_set_on_frame_send_callback(callbacks, &OnFrameSent); + nghttp2_session_callbacks_set_on_invalid_frame_recv_callback( + callbacks, &OnInvalidFrameReceived); + nghttp2_session_callbacks_set_error_callback2(callbacks, &OnError); + // on_frame_not_send_callback <- just ignored + nghttp2_session_callbacks_set_send_data_callback( + callbacks, &DataFrameSourceSendCallback); + nghttp2_session_callbacks_set_pack_extension_callback( + callbacks, &OnPackExtensionCallback); + nghttp2_session_callbacks_set_unpack_extension_callback( + callbacks, &OnUnpackExtensionCallback); + nghttp2_session_callbacks_set_on_extension_chunk_recv_callback( + callbacks, &OnExtensionChunkReceived); + return MakeCallbacksPtr(callbacks); } } // namespace callbacks diff --git a/gquiche/http2/adapter/nghttp2_callbacks.h b/gquiche/http2/adapter/nghttp2_callbacks.h index 3eea6bdf..a0eddd99 100644 --- a/gquiche/http2/adapter/nghttp2_callbacks.h +++ b/gquiche/http2/adapter/nghttp2_callbacks.h @@ -1,8 +1,11 @@ #ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_CALLBACKS_H_ #define QUICHE_HTTP2_ADAPTER_NGHTTP2_CALLBACKS_H_ +#include + #include "gquiche/http2/adapter/http2_protocol.h" -#include "third_party/nghttp2/src/lib/includes/nghttp2/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2_util.h" namespace http2 { namespace adapter { @@ -12,6 +15,13 @@ namespace callbacks { // beginning of its lifetime. It is expected that |user_data| holds an // Http2VisitorInterface. +// Callback once the library is ready to send serialized frames. +ssize_t OnReadyToSend(nghttp2_session* session, + const uint8_t* data, + size_t length, + int flags, + void* user_data); + // Callback once a frame header has been received. int OnBeginFrame(nghttp2_session* session, const nghttp2_frame_hd* header, void* user_data); @@ -30,7 +40,19 @@ int OnHeader(nghttp2_session* session, const nghttp2_frame* frame, nghttp2_rcbuf* name, nghttp2_rcbuf* value, uint8_t flags, void* user_data); -// Callback once a chunk of data (from a DATA frame payload) has been received. +// Invoked immediately before sending a frame. +int OnBeforeFrameSent(nghttp2_session* session, const nghttp2_frame* frame, + void* user_data); + +// Invoked immediately after a frame is sent. +int OnFrameSent(nghttp2_session* session, const nghttp2_frame* frame, + void* user_data); + +// Invoked when an invalid frame is received. +int OnInvalidFrameReceived(nghttp2_session* session, const nghttp2_frame* frame, + int lib_error_code, void* user_data); + +// Invoked when a chunk of data (from a DATA frame payload) has been received. int OnDataChunk(nghttp2_session* session, uint8_t flags, Http2StreamId stream_id, const uint8_t* data, size_t len, void* user_data); @@ -39,15 +61,27 @@ int OnDataChunk(nghttp2_session* session, uint8_t flags, int OnStreamClosed(nghttp2_session* session, Http2StreamId stream_id, uint32_t error_code, void* user_data); -// Callback once nghttp2 is ready to read data from |source| into |dest_buffer|. -ssize_t OnReadyToReadDataForStream(nghttp2_session* session, - Http2StreamId stream_id, - uint8_t* dest_buffer, size_t max_length, - uint32_t* data_flags, - nghttp2_data_source* source, - void* user_data); +// Invoked when nghttp2 has a chunk of extension frame data to pass to the +// application. +int OnExtensionChunkReceived(nghttp2_session* session, + const nghttp2_frame_hd* hd, const uint8_t* data, + size_t len, void* user_data); + +// Invoked when nghttp2 wants the application to unpack an extension payload. +int OnUnpackExtensionCallback(nghttp2_session* session, void** payload, + const nghttp2_frame_hd* hd, void* user_data); + +// Invoked when nghttp2 is ready to pack an extension payload. Returns the +// number of bytes serialized to |buf|. +ssize_t OnPackExtensionCallback(nghttp2_session* session, uint8_t* buf, + size_t len, const nghttp2_frame* frame, + void* user_data); + +// Invoked when the library has an error message to deliver. +int OnError(nghttp2_session* session, int lib_error_code, const char* msg, + size_t len, void* user_data); -nghttp2_session_callbacks* Create(); +nghttp2_session_callbacks_unique_ptr Create(); } // namespace callbacks } // namespace adapter diff --git a/gquiche/http2/adapter/nghttp2_data_provider.cc b/gquiche/http2/adapter/nghttp2_data_provider.cc new file mode 100644 index 00000000..69cba14e --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_data_provider.cc @@ -0,0 +1,63 @@ +#include "gquiche/http2/adapter/nghttp2_data_provider.h" + +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/nghttp2_util.h" + +namespace http2 { +namespace adapter { +namespace callbacks { + +namespace { +const size_t kFrameHeaderSize = 9; +} + +ssize_t DataFrameSourceReadCallback(nghttp2_session* /* session */, + int32_t /* stream_id */, + uint8_t* /* buf */, + size_t length, + uint32_t* data_flags, + nghttp2_data_source* source, + void* /* user_data */) { + *data_flags |= NGHTTP2_DATA_FLAG_NO_COPY; + auto* frame_source = static_cast(source->ptr); + auto [result_length, done] = frame_source->SelectPayloadLength(length); + if (result_length == 0 && !done) { + return NGHTTP2_ERR_DEFERRED; + } else if (result_length == DataFrameSource::kError) { + return NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; + } + if (done) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + if (!frame_source->send_fin()) { + *data_flags |= NGHTTP2_DATA_FLAG_NO_END_STREAM; + } + return result_length; +} + +int DataFrameSourceSendCallback(nghttp2_session* /* session */, + nghttp2_frame* /* frame */, + const uint8_t* framehd, + size_t length, + nghttp2_data_source* source, + void* /* user_data */) { + auto* frame_source = static_cast(source->ptr); + frame_source->Send(ToStringView(framehd, kFrameHeaderSize), length); + return 0; +} + +} // namespace callbacks + +std::unique_ptr MakeDataProvider( + DataFrameSource* source) { + if (source == nullptr) { + return nullptr; + } + auto provider = absl::make_unique(); + provider->source.ptr = source; + provider->read_callback = &callbacks::DataFrameSourceReadCallback; + return provider; +} + +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_data_provider.h b/gquiche/http2/adapter/nghttp2_data_provider.h new file mode 100644 index 00000000..8e9b847c --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_data_provider.h @@ -0,0 +1,40 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_DATA_PROVIDER_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_DATA_PROVIDER_H_ + +#include +#include + +#include "gquiche/http2/adapter/data_source.h" +#include "gquiche/http2/adapter/nghttp2.h" + +namespace http2 { +namespace adapter { +namespace callbacks { + +// Assumes |source| is a DataFrameSource. +ssize_t DataFrameSourceReadCallback(nghttp2_session* /*session */, + int32_t /* stream_id */, + uint8_t* /* buf */, + size_t length, + uint32_t* data_flags, + nghttp2_data_source* source, + void* /* user_data */); + +int DataFrameSourceSendCallback(nghttp2_session* /* session */, + nghttp2_frame* /* frame */, + const uint8_t* framehd, + size_t length, + nghttp2_data_source* source, + void* /* user_data */); + +} // namespace callbacks + +// Transforms a DataFrameSource into a nghttp2_data_provider. Does not take +// ownership of |source|. Returns nullptr if |source| is nullptr. +std::unique_ptr MakeDataProvider( + DataFrameSource* source); + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_DATA_PROVIDER_H_ diff --git a/gquiche/http2/adapter/nghttp2_data_provider_test.cc b/gquiche/http2/adapter/nghttp2_data_provider_test.cc new file mode 100644 index 00000000..41e5ff7d --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_data_provider_test.cc @@ -0,0 +1,117 @@ +#include "gquiche/http2/adapter/nghttp2_data_provider.h" + +#include "gquiche/http2/adapter/test_utils.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +const size_t kFrameHeaderSize = 9; + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the amount of data read is less +// than what the source provides. +TEST(DataProviderTest, ReadLessThanSourceProvides) { + DataSavingVisitor visitor; + TestDataFrameSource source(visitor, true); + source.AppendPayload("Example payload"); + source.EndData(); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 10; + // Read callback selects a payload length given an upper bound. + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + ASSERT_EQ(kReadLength, result); + EXPECT_EQ(NGHTTP2_DATA_FLAG_NO_COPY, data_flags); + + const uint8_t framehd[kFrameHeaderSize] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + // Sends the frame header and some payload bytes. + int send_result = callbacks::DataFrameSourceSendCallback( + nullptr, nullptr, framehd, result, &provider->source, nullptr); + EXPECT_EQ(0, send_result); + // Data accepted by the visitor includes a frame header and kReadLength bytes + // of payload. + EXPECT_EQ(visitor.data().size(), kFrameHeaderSize + kReadLength); +} + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the amount of data read is more +// than what the source provides. +TEST(DataProviderTest, ReadMoreThanSourceProvides) { + DataSavingVisitor visitor; + const absl::string_view kPayload = "Example payload"; + TestDataFrameSource source(visitor, true); + source.AppendPayload(kPayload); + source.EndData(); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 30; + // Read callback selects a payload length given an upper bound. + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + ASSERT_EQ(kPayload.size(), result); + EXPECT_EQ(NGHTTP2_DATA_FLAG_NO_COPY | NGHTTP2_DATA_FLAG_EOF, data_flags); + + const uint8_t framehd[kFrameHeaderSize] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + // Sends the frame header and some payload bytes. + int send_result = callbacks::DataFrameSourceSendCallback( + nullptr, nullptr, framehd, result, &provider->source, nullptr); + EXPECT_EQ(0, send_result); + // Data accepted by the visitor includes a frame header and the entire + // payload. + EXPECT_EQ(visitor.data().size(), kFrameHeaderSize + kPayload.size()); +} + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the source is blocked. +TEST(DataProviderTest, ReadFromBlockedSource) { + DataSavingVisitor visitor; + // Source has no payload, but also no fin, so it's blocked. + TestDataFrameSource source(visitor, false); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 10; + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + // Read operation is deferred, since the source is blocked. + EXPECT_EQ(NGHTTP2_ERR_DEFERRED, result); +} + +// Verifies that a nghttp2_data_provider derived from a DataFrameSource works +// correctly with nghttp2-style callbacks when the source provides only fin and +// no data. +TEST(DataProviderTest, ReadFromZeroLengthSource) { + DataSavingVisitor visitor; + // Empty payload and fin=true indicates the source is done. + TestDataFrameSource source(visitor, true); + source.EndData(); + auto provider = MakeDataProvider(&source); + uint32_t data_flags = 0; + const int32_t kStreamId = 1; + const size_t kReadLength = 10; + ssize_t result = + provider->read_callback(nullptr, kStreamId, nullptr, kReadLength, + &data_flags, &provider->source, nullptr); + ASSERT_EQ(0, result); + EXPECT_EQ(NGHTTP2_DATA_FLAG_NO_COPY | NGHTTP2_DATA_FLAG_EOF, data_flags); + + const uint8_t framehd[kFrameHeaderSize] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + int send_result = callbacks::DataFrameSourceSendCallback( + nullptr, nullptr, framehd, result, &provider->source, nullptr); + EXPECT_EQ(0, send_result); + // Data accepted by the visitor includes a frame header with fin and zero + // bytes of payload. + EXPECT_EQ(visitor.data().size(), kFrameHeaderSize); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_session.cc b/gquiche/http2/adapter/nghttp2_session.cc index 1c83b77b..cca14a98 100644 --- a/gquiche/http2/adapter/nghttp2_session.cc +++ b/gquiche/http2/adapter/nghttp2_session.cc @@ -4,38 +4,21 @@ namespace http2 { namespace adapter { -namespace { - -void DeleteSession(nghttp2_session* session) { - nghttp2_session_del(session); -} - -void DeleteOptions(nghttp2_option* options) { - nghttp2_option_del(options); -} - -} // namespace NgHttp2Session::NgHttp2Session(Perspective perspective, - nghttp2_session_callbacks* callbacks, - nghttp2_option* options, - void* userdata) - : session_(nullptr, DeleteSession), - options_(options, DeleteOptions), - perspective_(perspective) { + nghttp2_session_callbacks_unique_ptr callbacks, + const nghttp2_option* options, void* userdata) + : session_(MakeSessionPtr(nullptr)), perspective_(perspective) { nghttp2_session* session; - switch (perspective) { + switch (perspective_) { case Perspective::kClient: - nghttp2_session_client_new2(&session, callbacks, userdata, - options_.get()); + nghttp2_session_client_new2(&session, callbacks.get(), userdata, options); break; case Perspective::kServer: - nghttp2_session_server_new2(&session, callbacks, userdata, - options_.get()); + nghttp2_session_server_new2(&session, callbacks.get(), userdata, options); break; } - nghttp2_session_callbacks_del(callbacks); - session_.reset(session); + session_ = MakeSessionPtr(session); } NgHttp2Session::~NgHttp2Session() { @@ -47,7 +30,7 @@ NgHttp2Session::~NgHttp2Session() { << " or pending writes: " << pending_writes; } -ssize_t NgHttp2Session::ProcessBytes(absl::string_view bytes) { +int64_t NgHttp2Session::ProcessBytes(absl::string_view bytes) { return nghttp2_session_mem_recv( session_.get(), reinterpret_cast(bytes.data()), bytes.size()); diff --git a/gquiche/http2/adapter/nghttp2_session.h b/gquiche/http2/adapter/nghttp2_session.h index d52ecec1..57510d41 100644 --- a/gquiche/http2/adapter/nghttp2_session.h +++ b/gquiche/http2/adapter/nghttp2_session.h @@ -1,22 +1,26 @@ #ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_SESSION_H_ #define QUICHE_HTTP2_ADAPTER_NGHTTP2_SESSION_H_ +#include + #include "gquiche/http2/adapter/http2_session.h" -#include "third_party/nghttp2/src/lib/includes/nghttp2/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/common/platform/api/quiche_export.h" namespace http2 { namespace adapter { // A C++ wrapper around common nghttp2_session operations. -class NgHttp2Session : public Http2Session { +class QUICHE_EXPORT_PRIVATE NgHttp2Session : public Http2Session { public: + // Does not take ownership of |options|. NgHttp2Session(Perspective perspective, - nghttp2_session_callbacks* callbacks, - nghttp2_option* options, - void* userdata); + nghttp2_session_callbacks_unique_ptr callbacks, + const nghttp2_option* options, void* userdata); ~NgHttp2Session() override; - ssize_t ProcessBytes(absl::string_view bytes) override; + int64_t ProcessBytes(absl::string_view bytes) override; int Consume(Http2StreamId stream_id, size_t num_bytes) override; @@ -27,11 +31,7 @@ class NgHttp2Session : public Http2Session { nghttp2_session* raw_ptr() const { return session_.get(); } private: - using SessionDeleter = void (&)(nghttp2_session*); - using OptionsDeleter = void (&)(nghttp2_option*); - - std::unique_ptr session_; - std::unique_ptr options_; + nghttp2_session_unique_ptr session_; Perspective perspective_; }; diff --git a/gquiche/http2/adapter/nghttp2_session_test.cc b/gquiche/http2/adapter/nghttp2_session_test.cc index 4b246656..2afcfb9d 100644 --- a/gquiche/http2/adapter/nghttp2_session_test.cc +++ b/gquiche/http2/adapter/nghttp2_session_test.cc @@ -4,63 +4,59 @@ #include "gquiche/http2/adapter/nghttp2_callbacks.h" #include "gquiche/http2/adapter/nghttp2_util.h" #include "gquiche/http2/adapter/test_frame_sequence.h" +#include "gquiche/http2/adapter/test_utils.h" #include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/common/platform/api/quiche_test_helpers.h" namespace http2 { namespace adapter { namespace test { namespace { -class DataSavingVisitor : public testing::StrictMock { - public: - void Save(absl::string_view data) { absl::StrAppend(&data_, data); } - - const std::string& data() { return data_; } - - private: - std::string data_; +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, }; -ssize_t SaveSessionOutput(nghttp2_session* session, - const uint8_t* data, - size_t length, - int flags, - void* user_data) { - auto visitor = static_cast(user_data); - visitor->Save(ToStringView(data, length)); - return length; -} - class NgHttp2SessionTest : public testing::Test { public: - nghttp2_option* CreateOptions() { - nghttp2_option* options; - nghttp2_option_new(&options); - nghttp2_option_set_no_auto_window_update(options, 1); - return options; + void SetUp() override { + nghttp2_option_new(&options_); + nghttp2_option_set_no_auto_window_update(options_, 1); } - nghttp2_session_callbacks* CreateCallbacks() { - nghttp2_session_callbacks* callbacks = callbacks::Create(); - nghttp2_session_callbacks_set_send_callback(callbacks, &SaveSessionOutput); + void TearDown() override { nghttp2_option_del(options_); } + + nghttp2_session_callbacks_unique_ptr CreateCallbacks() { + nghttp2_session_callbacks_unique_ptr callbacks = callbacks::Create(); return callbacks; } DataSavingVisitor visitor_; + nghttp2_option* options_ = nullptr; }; TEST_F(NgHttp2SessionTest, ClientConstruction) { - NgHttp2Session session(Perspective::kClient, CreateCallbacks(), - CreateOptions(), &visitor_); + NgHttp2Session session(Perspective::kClient, CreateCallbacks(), options_, + &visitor_); EXPECT_TRUE(session.want_read()); EXPECT_FALSE(session.want_write()); - EXPECT_EQ(session.GetRemoteWindowSize(), kDefaultInitialStreamWindowSize); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); EXPECT_NE(session.raw_ptr(), nullptr); } TEST_F(NgHttp2SessionTest, ClientHandlesFrames) { - NgHttp2Session session(Perspective::kClient, CreateCallbacks(), - CreateOptions(), &visitor_); + NgHttp2Session session(Perspective::kClient, CreateCallbacks(), options_, + &visitor_); ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); ASSERT_GT(visitor_.data().size(), 0); @@ -73,36 +69,56 @@ TEST_F(NgHttp2SessionTest, ClientHandlesFrames) { testing::InSequence s; // Server preface (empty SETTINGS) + EXPECT_CALL(visitor_, OnFrameHeader(0, 0, SETTINGS, 0)); EXPECT_CALL(visitor_, OnSettingsStart()); EXPECT_CALL(visitor_, OnSettingsEnd()); + EXPECT_CALL(visitor_, OnFrameHeader(0, 8, PING, 0)); EXPECT_CALL(visitor_, OnPing(42, false)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); EXPECT_CALL(visitor_, OnWindowUpdate(0, 1000)); - const ssize_t initial_result = session.ProcessBytes(initial_frames); + const int64_t initial_result = session.ProcessBytes(initial_frames); EXPECT_EQ(initial_frames.size(), initial_result); EXPECT_EQ(session.GetRemoteWindowSize(), - kDefaultInitialStreamWindowSize + 1000); - ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); - - const std::vector
headers1 = {{":method", "GET"}, - {":scheme", "http"}, - {":authority", "example.com"}, - {":path", "/this/is/request/one"}}; - const auto nvs1 = GetRequestNghttp2Nvs(headers1); + kInitialFlowControlWindowSize + 1000); - const std::vector
headers2 = {{":method", "GET"}, - {":scheme", "http"}, - {":authority", "example.com"}, - {":path", "/this/is/request/two"}}; - const auto nvs2 = GetRequestNghttp2Nvs(headers2); + EXPECT_CALL(visitor_, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(PING, 0, 8, 0x1, 0)); - const std::vector
headers3 = {{":method", "GET"}, - {":scheme", "http"}, - {":authority", "example.com"}, - {":path", "/this/is/request/three"}}; - const auto nvs3 = GetRequestNghttp2Nvs(headers3); + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + // Some bytes should have been serialized. + absl::string_view serialized = visitor_.data(); + ASSERT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING})); + visitor_.Clear(); + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const auto nvs1 = GetNghttp2Nvs(headers1); + + const std::vector headers2 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}); + const auto nvs2 = GetNghttp2Nvs(headers2); + + const std::vector headers3 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/three"}}); + const auto nvs3 = GetNghttp2Nvs(headers3); const int32_t stream_id1 = nghttp2_submit_request( session.raw_ptr(), nullptr, nvs1.data(), nvs1.size(), nullptr, nullptr); @@ -119,7 +135,19 @@ TEST_F(NgHttp2SessionTest, ClientHandlesFrames) { ASSERT_GT(stream_id3, 0); QUICHE_LOG(INFO) << "Created stream: " << stream_id3; + EXPECT_CALL(visitor_, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor_, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(HEADERS, 3, _, 0x5)); + EXPECT_CALL(visitor_, OnFrameSent(HEADERS, 3, _, 0x5, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(HEADERS, 5, _, 0x5)); + EXPECT_CALL(visitor_, OnFrameSent(HEADERS, 5, _, 0x5, 0)); + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + serialized = visitor_.data(); + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::HEADERS})); + visitor_.Clear(); const std::string stream_frames = TestFrameSequence() @@ -133,35 +161,62 @@ TEST_F(NgHttp2SessionTest, ClientHandlesFrames) { .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") .Serialize(); + EXPECT_CALL(visitor_, OnFrameHeader(1, _, HEADERS, 4)); EXPECT_CALL(visitor_, OnBeginHeadersForStream(1)); EXPECT_CALL(visitor_, OnHeaderForStream(1, ":status", "200")); EXPECT_CALL(visitor_, OnHeaderForStream(1, "server", "my-fake-server")); EXPECT_CALL(visitor_, OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); EXPECT_CALL(visitor_, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor_, OnFrameHeader(1, 26, DATA, 0)); EXPECT_CALL(visitor_, OnBeginDataForStream(1, 26)); EXPECT_CALL(visitor_, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor_, OnFrameHeader(3, 4, RST_STREAM, 0)); EXPECT_CALL(visitor_, OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR)); - EXPECT_CALL(visitor_, OnAbortStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor_, OnCloseStream(3, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 19, GOAWAY, 0)); EXPECT_CALL(visitor_, OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!")); - const ssize_t stream_result = session.ProcessBytes(stream_frames); + const int64_t stream_result = session.ProcessBytes(stream_frames); EXPECT_EQ(stream_frames.size(), stream_result); + + // Even though the client recieved a GOAWAY, streams 1 and 5 are still active. + EXPECT_TRUE(session.want_read()); + + EXPECT_CALL(visitor_, OnFrameHeader(1, 0, DATA, 1)); + EXPECT_CALL(visitor_, OnBeginDataForStream(1, 0)); + EXPECT_CALL(visitor_, OnEndStream(1)); + EXPECT_CALL(visitor_, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor_, OnFrameHeader(5, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor_, OnRstStream(5, Http2ErrorCode::REFUSED_STREAM)); + EXPECT_CALL(visitor_, OnCloseStream(5, Http2ErrorCode::REFUSED_STREAM)); + session.ProcessBytes(TestFrameSequence() + .Data(1, "", true) + .RstStream(5, Http2ErrorCode::REFUSED_STREAM) + .Serialize()); + // After receiving END_STREAM for 1 and RST_STREAM for 5, the session no + // longer expects reads. + EXPECT_FALSE(session.want_read()); + + // Client will not have anything else to write. + EXPECT_FALSE(session.want_write()); ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + serialized = visitor_.data(); + EXPECT_EQ(serialized.size(), 0); } TEST_F(NgHttp2SessionTest, ServerConstruction) { - NgHttp2Session session(Perspective::kServer, CreateCallbacks(), - CreateOptions(), &visitor_); + NgHttp2Session session(Perspective::kServer, CreateCallbacks(), options_, + &visitor_); EXPECT_TRUE(session.want_read()); EXPECT_FALSE(session.want_write()); - EXPECT_EQ(session.GetRemoteWindowSize(), kDefaultInitialStreamWindowSize); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); EXPECT_NE(session.raw_ptr(), nullptr); } TEST_F(NgHttp2SessionTest, ServerHandlesFrames) { - NgHttp2Session session(Perspective::kServer, CreateCallbacks(), - CreateOptions(), &visitor_); + NgHttp2Session session(Perspective::kServer, CreateCallbacks(), options_, + &visitor_); const std::string frames = TestFrameSequence() .ClientPreface() @@ -187,20 +242,27 @@ TEST_F(NgHttp2SessionTest, ServerHandlesFrames) { testing::InSequence s; // Client preface (empty SETTINGS) + EXPECT_CALL(visitor_, OnFrameHeader(0, 0, SETTINGS, 0)); EXPECT_CALL(visitor_, OnSettingsStart()); EXPECT_CALL(visitor_, OnSettingsEnd()); + EXPECT_CALL(visitor_, OnFrameHeader(0, 8, PING, 0)); EXPECT_CALL(visitor_, OnPing(42, false)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); EXPECT_CALL(visitor_, OnWindowUpdate(0, 1000)); + EXPECT_CALL(visitor_, OnFrameHeader(1, _, HEADERS, 4)); EXPECT_CALL(visitor_, OnBeginHeadersForStream(1)); EXPECT_CALL(visitor_, OnHeaderForStream(1, ":method", "POST")); EXPECT_CALL(visitor_, OnHeaderForStream(1, ":scheme", "https")); EXPECT_CALL(visitor_, OnHeaderForStream(1, ":authority", "example.com")); EXPECT_CALL(visitor_, OnHeaderForStream(1, ":path", "/this/is/request/one")); EXPECT_CALL(visitor_, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor_, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); EXPECT_CALL(visitor_, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor_, OnFrameHeader(1, 25, DATA, 0)); EXPECT_CALL(visitor_, OnBeginDataForStream(1, 25)); EXPECT_CALL(visitor_, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor_, OnFrameHeader(3, _, HEADERS, 5)); EXPECT_CALL(visitor_, OnBeginHeadersForStream(3)); EXPECT_CALL(visitor_, OnHeaderForStream(3, ":method", "GET")); EXPECT_CALL(visitor_, OnHeaderForStream(3, ":scheme", "http")); @@ -208,15 +270,50 @@ TEST_F(NgHttp2SessionTest, ServerHandlesFrames) { EXPECT_CALL(visitor_, OnHeaderForStream(3, ":path", "/this/is/request/two")); EXPECT_CALL(visitor_, OnEndHeadersForStream(3)); EXPECT_CALL(visitor_, OnEndStream(3)); + EXPECT_CALL(visitor_, OnFrameHeader(3, 4, RST_STREAM, 0)); EXPECT_CALL(visitor_, OnRstStream(3, Http2ErrorCode::CANCEL)); - EXPECT_CALL(visitor_, OnAbortStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor_, OnCloseStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor_, OnFrameHeader(0, 8, PING, 0)); EXPECT_CALL(visitor_, OnPing(47, false)); - const ssize_t result = session.ProcessBytes(frames); + const int64_t result = session.ProcessBytes(frames); EXPECT_EQ(frames.size(), result); EXPECT_EQ(session.GetRemoteWindowSize(), - kDefaultInitialStreamWindowSize + 1000); + kInitialFlowControlWindowSize + 1000); + + EXPECT_CALL(visitor_, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(PING, 0, 8, 0x1, 0)); + EXPECT_CALL(visitor_, OnBeforeFrameSent(PING, 0, 8, 0x1)); + EXPECT_CALL(visitor_, OnFrameSent(PING, 0, 8, 0x1, 0)); + + EXPECT_TRUE(session.want_write()); + ASSERT_EQ(0, nghttp2_session_send(session.raw_ptr())); + // Some bytes should have been serialized. + absl::string_view serialized = visitor_.data(); + // SETTINGS ack, two PING acks. + EXPECT_THAT(serialized, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING, + spdy::SpdyFrameType::PING})); +} + +// Verifies that a null payload is caught by the OnPackExtensionCallback +// implementation. +TEST_F(NgHttp2SessionTest, NullPayload) { + NgHttp2Session session(Perspective::kClient, CreateCallbacks(), options_, + &visitor_); + + void* payload = nullptr; + const int result = nghttp2_submit_extension( + session.raw_ptr(), kMetadataFrameType, 0, 1, payload); + ASSERT_EQ(0, result); + EXPECT_TRUE(session.want_write()); + int send_result = -1; + EXPECT_QUICHE_BUG(send_result = nghttp2_session_send(session.raw_ptr()), + "Extension frame payload for stream 1 is null!"); + EXPECT_EQ(NGHTTP2_ERR_CALLBACK_FAILURE, send_result); } } // namespace diff --git a/gquiche/http2/adapter/nghttp2_test.cc b/gquiche/http2/adapter/nghttp2_test.cc new file mode 100644 index 00000000..a03835e6 --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_test.cc @@ -0,0 +1,205 @@ +#include "gquiche/http2/adapter/nghttp2.h" + +#include "absl/strings/str_cat.h" +#include "gquiche/http2/adapter/mock_nghttp2_callbacks.h" +#include "gquiche/http2/adapter/nghttp2_test_utils.h" +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/http2/adapter/test_frame_sequence.h" +#include "gquiche/http2/adapter/test_utils.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, +}; + +nghttp2_option* GetOptions() { + nghttp2_option* options; + nghttp2_option_new(&options); + // Set some common options for compatibility. + nghttp2_option_set_no_closed_streams(options, 1); + nghttp2_option_set_no_auto_window_update(options, 1); + nghttp2_option_set_max_send_header_block_length(options, 0x2000000); + nghttp2_option_set_max_outbound_ack(options, 10000); + return options; +} + +class Nghttp2Test : public testing::Test { + public: + Nghttp2Test() : session_(MakeSessionPtr(nullptr)) {} + + void SetUp() override { InitializeSession(); } + + virtual Perspective GetPerspective() = 0; + + void InitializeSession() { + auto nghttp2_callbacks = MockNghttp2Callbacks::GetCallbacks(); + nghttp2_option* options = GetOptions(); + nghttp2_session* ptr; + if (GetPerspective() == Perspective::kClient) { + nghttp2_session_client_new2(&ptr, nghttp2_callbacks.get(), + &mock_callbacks_, options); + } else { + nghttp2_session_server_new2(&ptr, nghttp2_callbacks.get(), + &mock_callbacks_, options); + } + nghttp2_option_del(options); + + // Sets up the Send() callback to append to |serialized_|. + EXPECT_CALL(mock_callbacks_, Send(_, _, _)) + .WillRepeatedly( + [this](const uint8_t* data, size_t length, int /*flags*/) { + absl::StrAppend(&serialized_, ToStringView(data, length)); + return length; + }); + // Sets up the SendData() callback to fetch and append data from a + // TestDataSource. + EXPECT_CALL(mock_callbacks_, SendData(_, _, _, _)) + .WillRepeatedly([this](nghttp2_frame* /*frame*/, const uint8_t* framehd, + size_t length, nghttp2_data_source* source) { + QUICHE_LOG(INFO) << "Appending frame header and " << length + << " bytes of data"; + auto* s = static_cast(source->ptr); + absl::StrAppend(&serialized_, ToStringView(framehd, 9), + s->ReadNext(length)); + return 0; + }); + session_ = MakeSessionPtr(ptr); + } + + testing::StrictMock mock_callbacks_; + nghttp2_session_unique_ptr session_; + std::string serialized_; +}; + +class Nghttp2ClientTest : public Nghttp2Test { + public: + Perspective GetPerspective() override { return Perspective::kClient; } +}; + +// Verifies nghttp2 behavior when acting as a client. +TEST_F(Nghttp2ClientTest, ClientReceivesUnexpectedHeaders) { + const std::string initial_frames = TestFrameSequence() + .ServerPreface() + .Ping(42) + .WindowUpdate(0, 1000) + .Serialize(); + + testing::InSequence seq; + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, SETTINGS, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsSettings(testing::IsEmpty()))); + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, PING, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsPing(42))); + EXPECT_CALL(mock_callbacks_, + OnBeginFrame(HasFrameHeader(0, WINDOW_UPDATE, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsWindowUpdate(1000))); + + ssize_t result = nghttp2_session_mem_recv( + session_.get(), ToUint8Ptr(initial_frames.data()), initial_frames.size()); + ASSERT_EQ(result, initial_frames.size()); + + const std::string unexpected_stream_frames = + TestFrameSequence() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .RstStream(3, Http2ErrorCode::INTERNAL_ERROR) + .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") + .Serialize(); + + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(1, HEADERS, _))); + EXPECT_CALL(mock_callbacks_, OnInvalidFrameRecv(IsHeaders(1, _, _), _)); + // No events from the DATA, RST_STREAM or GOAWAY. + + nghttp2_session_mem_recv(session_.get(), + ToUint8Ptr(unexpected_stream_frames.data()), + unexpected_stream_frames.size()); +} + +// Tests the request-sending behavior of nghttp2 when acting as a client. +TEST_F(Nghttp2ClientTest, ClientSendsRequest) { + int result = nghttp2_session_send(session_.get()); + ASSERT_EQ(result, 0); + + EXPECT_THAT(serialized_, testing::StrEq(spdy::kHttp2ConnectionHeaderPrefix)); + serialized_.clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(mock_callbacks_, OnBeginFrame(HasFrameHeader(0, SETTINGS, 0))); + EXPECT_CALL(mock_callbacks_, OnFrameRecv(IsSettings(testing::IsEmpty()))); + + ssize_t recv_result = nghttp2_session_mem_recv( + session_.get(), ToUint8Ptr(initial_frames.data()), initial_frames.size()); + EXPECT_EQ(initial_frames.size(), recv_result); + + // Client wants to send a SETTINGS ack. + EXPECT_CALL(mock_callbacks_, BeforeFrameSend(IsSettings(testing::IsEmpty()))); + EXPECT_CALL(mock_callbacks_, OnFrameSend(IsSettings(testing::IsEmpty()))); + EXPECT_TRUE(nghttp2_session_want_write(session_.get())); + result = nghttp2_session_send(session_.get()); + EXPECT_THAT(serialized_, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + serialized_.clear(); + + EXPECT_FALSE(nghttp2_session_want_write(session_.get())); + + // The following sets up the client request. + std::vector> headers = { + {":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}; + std::vector nvs; + for (const auto& h : headers) { + nvs.push_back({.name = ToUint8Ptr(h.first.data()), + .value = ToUint8Ptr(h.second.data()), + .namelen = h.first.size(), + .valuelen = h.second.size()}); + } + const absl::string_view kBody = "This is an example request body."; + TestDataSource source{kBody}; + nghttp2_data_provider provider = source.MakeDataProvider(); + // After submitting the request, the client will want to write. + int stream_id = + nghttp2_submit_request(session_.get(), nullptr /* pri_spec */, nvs.data(), + nvs.size(), &provider, nullptr /* stream_data */); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(nghttp2_session_want_write(session_.get())); + + // We expect that the client will want to write HEADERS, then DATA. + EXPECT_CALL(mock_callbacks_, BeforeFrameSend(IsHeaders(stream_id, _, _))); + EXPECT_CALL(mock_callbacks_, OnFrameSend(IsHeaders(stream_id, _, _))); + EXPECT_CALL(mock_callbacks_, OnFrameSend(IsData(stream_id, kBody.size(), _))); + nghttp2_session_send(session_.get()); + EXPECT_THAT(serialized_, EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(serialized_, testing::HasSubstr(kBody)); + + // Once the request is flushed, the client no longer wants to write. + EXPECT_FALSE(nghttp2_session_want_write(session_.get())); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_test_utils.cc b/gquiche/http2/adapter/nghttp2_test_utils.cc new file mode 100644 index 00000000..127ced80 --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_test_utils.cc @@ -0,0 +1,454 @@ +#include "gquiche/http2/adapter/nghttp2_test_utils.h" + +#include "gquiche/http2/adapter/nghttp2_util.h" +#include "gquiche/common/quiche_endian.h" + +namespace http2 { +namespace adapter { +namespace test { + +namespace { + +// Custom gMock matcher, used to implement HasFrameHeader(). +class FrameHeaderMatcher { + public: + FrameHeaderMatcher(int32_t streamid, uint8_t type, + const testing::Matcher flags) + : stream_id_(streamid), type_(type), flags_(flags) {} + + bool Match(const nghttp2_frame_hd& frame, + testing::MatchResultListener* listener) const { + bool matched = true; + if (stream_id_ != frame.stream_id) { + *listener << "; expected stream " << stream_id_ << ", saw " + << frame.stream_id; + matched = false; + } + if (type_ != frame.type) { + *listener << "; expected frame type " << type_ << ", saw " + << static_cast(frame.type); + matched = false; + } + if (!flags_.MatchAndExplain(frame.flags, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const { + *os << "contains a frame header with stream " << stream_id_ << ", type " + << type_ << ", "; + flags_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const { + *os << "does not contain a frame header with stream " << stream_id_ + << ", type " << type_ << ", "; + flags_.DescribeNegationTo(os); + } + + private: + const int32_t stream_id_; + const int type_; + const testing::Matcher flags_; +}; + +class PointerToFrameHeaderMatcher + : public FrameHeaderMatcher, + public testing::MatcherInterface { + public: + PointerToFrameHeaderMatcher(int32_t streamid, uint8_t type, + const testing::Matcher flags) + : FrameHeaderMatcher(streamid, type, flags) {} + + bool MatchAndExplain(const nghttp2_frame_hd* frame, + testing::MatchResultListener* listener) const override { + return FrameHeaderMatcher::Match(*frame, listener); + } + + void DescribeTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeNegationTo(os); + } +}; + +class ReferenceToFrameHeaderMatcher + : public FrameHeaderMatcher, + public testing::MatcherInterface { + public: + ReferenceToFrameHeaderMatcher(int32_t streamid, uint8_t type, + const testing::Matcher flags) + : FrameHeaderMatcher(streamid, type, flags) {} + + bool MatchAndExplain(const nghttp2_frame_hd& frame, + testing::MatchResultListener* listener) const override { + return FrameHeaderMatcher::Match(frame, listener); + } + + void DescribeTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + FrameHeaderMatcher::DescribeNegationTo(os); + } +}; + +class DataMatcher : public testing::MatcherInterface { + public: + DataMatcher(const testing::Matcher stream_id, + const testing::Matcher length, + const testing::Matcher flags) + : stream_id_(stream_id), length_(length), flags_(flags) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_DATA) { + *listener << "; expected DATA frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!stream_id_.MatchAndExplain(frame->hd.stream_id, listener)) { + matched = false; + } + if (!length_.MatchAndExplain(frame->hd.length, listener)) { + matched = false; + } + if (!flags_.MatchAndExplain(frame->hd.flags, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a DATA frame, "; + stream_id_.DescribeTo(os); + length_.DescribeTo(os); + flags_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a DATA frame, "; + stream_id_.DescribeNegationTo(os); + length_.DescribeNegationTo(os); + flags_.DescribeNegationTo(os); + } + + private: + const testing::Matcher stream_id_; + const testing::Matcher length_; + const testing::Matcher flags_; +}; + +class HeadersMatcher : public testing::MatcherInterface { + public: + HeadersMatcher(const testing::Matcher stream_id, + const testing::Matcher flags, + const testing::Matcher category) + : stream_id_(stream_id), flags_(flags), category_(category) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_HEADERS) { + *listener << "; expected HEADERS frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!stream_id_.MatchAndExplain(frame->hd.stream_id, listener)) { + matched = false; + } + if (!flags_.MatchAndExplain(frame->hd.flags, listener)) { + matched = false; + } + if (!category_.MatchAndExplain(frame->headers.cat, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a HEADERS frame, "; + stream_id_.DescribeTo(os); + flags_.DescribeTo(os); + category_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a HEADERS frame, "; + stream_id_.DescribeNegationTo(os); + flags_.DescribeNegationTo(os); + category_.DescribeNegationTo(os); + } + + private: + const testing::Matcher stream_id_; + const testing::Matcher flags_; + const testing::Matcher category_; +}; + +class RstStreamMatcher + : public testing::MatcherInterface { + public: + RstStreamMatcher(const testing::Matcher stream_id, + const testing::Matcher error_code) + : stream_id_(stream_id), error_code_(error_code) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_RST_STREAM) { + *listener << "; expected RST_STREAM frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!stream_id_.MatchAndExplain(frame->hd.stream_id, listener)) { + matched = false; + } + if (!error_code_.MatchAndExplain(frame->rst_stream.error_code, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a RST_STREAM frame, "; + stream_id_.DescribeTo(os); + error_code_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a RST_STREAM frame, "; + stream_id_.DescribeNegationTo(os); + error_code_.DescribeNegationTo(os); + } + + private: + const testing::Matcher stream_id_; + const testing::Matcher error_code_; +}; + +class SettingsMatcher : public testing::MatcherInterface { + public: + SettingsMatcher(const testing::Matcher> values) + : values_(values) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_SETTINGS) { + *listener << "; expected SETTINGS frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + std::vector settings; + settings.reserve(frame->settings.niv); + for (size_t i = 0; i < frame->settings.niv; ++i) { + const auto& p = frame->settings.iv[i]; + settings.push_back({static_cast(p.settings_id), p.value}); + } + return values_.MatchAndExplain(settings, listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a SETTINGS frame, "; + values_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a SETTINGS frame, "; + values_.DescribeNegationTo(os); + } + + private: + const testing::Matcher> values_; +}; + +class PingMatcher : public testing::MatcherInterface { + public: + PingMatcher(const testing::Matcher id, bool is_ack) + : id_(id), is_ack_(is_ack) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_PING) { + *listener << "; expected PING frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + bool frame_ack = frame->hd.flags & NGHTTP2_FLAG_ACK; + if (is_ack_ != frame_ack) { + *listener << "; expected is_ack=" << is_ack_ << ", saw " << frame_ack; + matched = false; + } + uint64_t data; + std::memcpy(&data, frame->ping.opaque_data, sizeof(data)); + data = quiche::QuicheEndian::HostToNet64(data); + if (!id_.MatchAndExplain(data, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a PING frame, "; + id_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a PING frame, "; + id_.DescribeNegationTo(os); + } + + private: + const testing::Matcher id_; + const bool is_ack_; +}; + +class GoAwayMatcher : public testing::MatcherInterface { + public: + GoAwayMatcher(const testing::Matcher last_stream_id, + const testing::Matcher error_code, + const testing::Matcher opaque_data) + : last_stream_id_(last_stream_id), + error_code_(error_code), + opaque_data_(opaque_data) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_GOAWAY) { + *listener << "; expected GOAWAY frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + bool matched = true; + if (!last_stream_id_.MatchAndExplain(frame->goaway.last_stream_id, + listener)) { + matched = false; + } + if (!error_code_.MatchAndExplain(frame->goaway.error_code, listener)) { + matched = false; + } + auto opaque_data = + ToStringView(frame->goaway.opaque_data, frame->goaway.opaque_data_len); + if (!opaque_data_.MatchAndExplain(opaque_data, listener)) { + matched = false; + } + return matched; + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a GOAWAY frame, "; + last_stream_id_.DescribeTo(os); + error_code_.DescribeTo(os); + opaque_data_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a GOAWAY frame, "; + last_stream_id_.DescribeNegationTo(os); + error_code_.DescribeNegationTo(os); + opaque_data_.DescribeNegationTo(os); + } + + private: + const testing::Matcher last_stream_id_; + const testing::Matcher error_code_; + const testing::Matcher opaque_data_; +}; + +class WindowUpdateMatcher + : public testing::MatcherInterface { + public: + WindowUpdateMatcher(const testing::Matcher delta) : delta_(delta) {} + + bool MatchAndExplain(const nghttp2_frame* frame, + testing::MatchResultListener* listener) const override { + if (frame->hd.type != NGHTTP2_WINDOW_UPDATE) { + *listener << "; expected WINDOW_UPDATE frame, saw frame of type " + << static_cast(frame->hd.type); + return false; + } + return delta_.MatchAndExplain(frame->window_update.window_size_increment, + listener); + } + + void DescribeTo(std::ostream* os) const override { + *os << "contains a WINDOW_UPDATE frame, "; + delta_.DescribeTo(os); + } + + void DescribeNegationTo(std::ostream* os) const override { + *os << "does not contain a WINDOW_UPDATE frame, "; + delta_.DescribeNegationTo(os); + } + + private: + const testing::Matcher delta_; +}; + +} // namespace + +testing::Matcher HasFrameHeader( + uint32_t streamid, uint8_t type, const testing::Matcher flags) { + return MakeMatcher(new PointerToFrameHeaderMatcher(streamid, type, flags)); +} + +testing::Matcher HasFrameHeaderRef( + uint32_t streamid, uint8_t type, const testing::Matcher flags) { + return MakeMatcher(new ReferenceToFrameHeaderMatcher(streamid, type, flags)); +} + +testing::Matcher IsData( + const testing::Matcher stream_id, + const testing::Matcher length, const testing::Matcher flags) { + return MakeMatcher(new DataMatcher(stream_id, length, flags)); +} + +testing::Matcher IsHeaders( + const testing::Matcher stream_id, + const testing::Matcher flags, const testing::Matcher category) { + return MakeMatcher(new HeadersMatcher(stream_id, flags, category)); +} + +testing::Matcher IsRstStream( + const testing::Matcher stream_id, + const testing::Matcher error_code) { + return MakeMatcher(new RstStreamMatcher(stream_id, error_code)); +} + +testing::Matcher IsSettings( + const testing::Matcher> values) { + return MakeMatcher(new SettingsMatcher(values)); +} + +testing::Matcher IsPing( + const testing::Matcher id) { + return MakeMatcher(new PingMatcher(id, false)); +} + +testing::Matcher IsPingAck( + const testing::Matcher id) { + return MakeMatcher(new PingMatcher(id, true)); +} + +testing::Matcher IsGoAway( + const testing::Matcher last_stream_id, + const testing::Matcher error_code, + const testing::Matcher opaque_data) { + return MakeMatcher( + new GoAwayMatcher(last_stream_id, error_code, opaque_data)); +} + +testing::Matcher IsWindowUpdate( + const testing::Matcher delta) { + return MakeMatcher(new WindowUpdateMatcher(delta)); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_test_utils.h b/gquiche/http2/adapter/nghttp2_test_utils.h new file mode 100644 index 00000000..2759938e --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_test_utils.h @@ -0,0 +1,100 @@ +#ifndef QUICHE_HTTP2_ADAPTER_NGHTTP2_TEST_UTILS_H_ +#define QUICHE_HTTP2_ADAPTER_NGHTTP2_TEST_UTILS_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/nghttp2.h" +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// A simple class that can easily be adapted to act as a nghttp2_data_source. +class QUICHE_NO_EXPORT TestDataSource { + public: + explicit TestDataSource(absl::string_view data) : data_(std::string(data)) {} + + absl::string_view ReadNext(size_t size) { + const size_t to_send = std::min(size, remaining_.size()); + auto ret = remaining_.substr(0, to_send); + remaining_.remove_prefix(to_send); + return ret; + } + + size_t SelectPayloadLength(size_t max_length) { + return std::min(max_length, remaining_.size()); + } + + nghttp2_data_provider MakeDataProvider() { + return nghttp2_data_provider{ + .source = {.ptr = this}, + .read_callback = [](nghttp2_session*, int32_t, uint8_t*, size_t length, + uint32_t* data_flags, nghttp2_data_source* source, + void*) -> ssize_t { + *data_flags |= NGHTTP2_DATA_FLAG_NO_COPY; + auto* s = static_cast(source->ptr); + if (!s->is_data_available()) { + return NGHTTP2_ERR_DEFERRED; + } + const ssize_t ret = s->SelectPayloadLength(length); + if (ret < static_cast(length)) { + *data_flags |= NGHTTP2_DATA_FLAG_EOF; + } + return ret; + }}; + } + + bool is_data_available() const { return is_data_available_; } + void set_is_data_available(bool value) { is_data_available_ = value; } + + private: + const std::string data_; + absl::string_view remaining_ = data_; + bool is_data_available_ = true; +}; + +// Matchers for nghttp2 data types. +testing::Matcher HasFrameHeader( + uint32_t streamid, uint8_t type, const testing::Matcher flags); +testing::Matcher HasFrameHeaderRef( + uint32_t streamid, uint8_t type, const testing::Matcher flags); + +testing::Matcher IsData( + const testing::Matcher stream_id, + const testing::Matcher length, const testing::Matcher flags); + +testing::Matcher IsHeaders( + const testing::Matcher stream_id, + const testing::Matcher flags, const testing::Matcher category); + +testing::Matcher IsRstStream( + const testing::Matcher stream_id, + const testing::Matcher error_code); + +testing::Matcher IsSettings( + const testing::Matcher> values); + +testing::Matcher IsPing( + const testing::Matcher id); + +testing::Matcher IsPingAck( + const testing::Matcher id); + +testing::Matcher IsGoAway( + const testing::Matcher last_stream_id, + const testing::Matcher error_code, + const testing::Matcher opaque_data); + +testing::Matcher IsWindowUpdate( + const testing::Matcher delta); + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_NGHTTP2_TEST_UTILS_H_ diff --git a/gquiche/http2/adapter/nghttp2_util.cc b/gquiche/http2/adapter/nghttp2_util.cc index 89634ad4..adf67f7e 100644 --- a/gquiche/http2/adapter/nghttp2_util.cc +++ b/gquiche/http2/adapter/nghttp2_util.cc @@ -2,14 +2,43 @@ #include +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "gquiche/http2/adapter/http2_protocol.h" -#include "third_party/nghttp2/src/lib/includes/nghttp2/nghttp2.h" #include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/common/quiche_endian.h" namespace http2 { namespace adapter { +namespace { + +using InvalidFrameError = Http2VisitorInterface::InvalidFrameError; + +void DeleteCallbacks(nghttp2_session_callbacks* callbacks) { + if (callbacks) { + nghttp2_session_callbacks_del(callbacks); + } +} + +void DeleteSession(nghttp2_session* session) { + if (session) { + nghttp2_session_del(session); + } +} + +} // namespace + +nghttp2_session_callbacks_unique_ptr MakeCallbacksPtr( + nghttp2_session_callbacks* callbacks) { + return nghttp2_session_callbacks_unique_ptr(callbacks, &DeleteCallbacks); +} + +nghttp2_session_unique_ptr MakeSessionPtr(nghttp2_session* session) { + return nghttp2_session_unique_ptr(session, &DeleteSession); +} + uint8_t* ToUint8Ptr(char* str) { return reinterpret_cast(str); } uint8_t* ToUint8Ptr(const char* str) { return const_cast(reinterpret_cast(str)); @@ -29,16 +58,27 @@ absl::string_view ToStringView(const uint8_t* pointer, size_t length) { return absl::string_view(reinterpret_cast(pointer), length); } -std::vector GetRequestNghttp2Nvs(absl::Span headers) { +std::vector GetNghttp2Nvs(absl::Span headers) { const int num_headers = headers.size(); - auto nghttp2_nvs = std::vector(num_headers); + std::vector nghttp2_nvs; + nghttp2_nvs.reserve(num_headers); for (int i = 0; i < num_headers; ++i) { nghttp2_nv header; - header.name = ToUint8Ptr(&headers[i].first[0]); - header.namelen = headers[i].first.size(); - header.value = ToUint8Ptr(&headers[i].second[0]); - header.valuelen = headers[i].second.size(); - header.flags = NGHTTP2_FLAG_NONE; + uint8_t flags = NGHTTP2_NV_FLAG_NONE; + + const auto [name, no_copy_name] = GetStringView(headers[i].first); + header.name = ToUint8Ptr(name.data()); + header.namelen = name.size(); + if (no_copy_name) { + flags |= NGHTTP2_NV_FLAG_NO_COPY_NAME; + } + const auto [value, no_copy_value] = GetStringView(headers[i].second); + header.value = ToUint8Ptr(value.data()); + header.valuelen = value.size(); + if (no_copy_value) { + flags |= NGHTTP2_NV_FLAG_NO_COPY_VALUE; + } + header.flags = flags; nghttp2_nvs.push_back(std::move(header)); } @@ -50,7 +90,8 @@ std::vector GetResponseNghttp2Nvs( absl::string_view response_code) { // Allocate enough for all headers and also the :status pseudoheader. const int num_headers = headers.size(); - auto nghttp2_nvs = std::vector(num_headers + 1); + std::vector nghttp2_nvs; + nghttp2_nvs.reserve(num_headers + 1); // Add the :status pseudoheader first. nghttp2_nv status; @@ -62,7 +103,7 @@ std::vector GetResponseNghttp2Nvs( nghttp2_nvs.push_back(std::move(status)); // Add the remaining headers. - for (const auto header_pair : headers) { + for (const auto& header_pair : headers) { nghttp2_nv header; header.name = ToUint8Ptr(header_pair.first.data()); header.namelen = header_pair.first.size(); @@ -82,5 +123,183 @@ Http2ErrorCode ToHttp2ErrorCode(uint32_t wire_error_code) { return static_cast(wire_error_code); } +int ToNgHttp2ErrorCode(InvalidFrameError error) { + switch (error) { + case InvalidFrameError::kProtocol: + return NGHTTP2_ERR_PROTO; + case InvalidFrameError::kRefusedStream: + return NGHTTP2_ERR_REFUSED_STREAM; + case InvalidFrameError::kHttpHeader: + return NGHTTP2_ERR_HTTP_HEADER; + case InvalidFrameError::kHttpMessaging: + return NGHTTP2_ERR_HTTP_MESSAGING; + case InvalidFrameError::kFlowControl: + return NGHTTP2_ERR_FLOW_CONTROL; + case InvalidFrameError::kStreamClosed: + return NGHTTP2_ERR_STREAM_CLOSED; + } + return NGHTTP2_ERR_PROTO; +} + +InvalidFrameError ToInvalidFrameError(int error) { + switch (error) { + case NGHTTP2_ERR_PROTO: + return InvalidFrameError::kProtocol; + case NGHTTP2_ERR_REFUSED_STREAM: + return InvalidFrameError::kRefusedStream; + case NGHTTP2_ERR_HTTP_HEADER: + return InvalidFrameError::kHttpHeader; + case NGHTTP2_ERR_HTTP_MESSAGING: + return InvalidFrameError::kHttpMessaging; + case NGHTTP2_ERR_FLOW_CONTROL: + return InvalidFrameError::kFlowControl; + case NGHTTP2_ERR_STREAM_CLOSED: + return InvalidFrameError::kStreamClosed; + } + return InvalidFrameError::kProtocol; +} + +class Nghttp2DataFrameSource : public DataFrameSource { + public: + Nghttp2DataFrameSource(nghttp2_data_provider provider, + nghttp2_send_data_callback send_data, + void* user_data) + : provider_(std::move(provider)), + send_data_(std::move(send_data)), + user_data_(user_data) {} + + std::pair SelectPayloadLength(size_t max_length) override { + const int32_t stream_id = 0; + uint32_t data_flags = 0; + int64_t result = provider_.read_callback( + nullptr /* session */, stream_id, nullptr /* buf */, max_length, + &data_flags, &provider_.source, nullptr /* user_data */); + if (result == NGHTTP2_ERR_DEFERRED) { + return {kBlocked, false}; + } else if (result < 0) { + return {kError, false}; + } else if ((data_flags & NGHTTP2_DATA_FLAG_NO_COPY) == 0) { + QUICHE_LOG(ERROR) << "Source did not use the zero-copy API!"; + return {kError, false}; + } else { + const bool eof = data_flags & NGHTTP2_DATA_FLAG_EOF; + if (eof && (data_flags & NGHTTP2_DATA_FLAG_NO_END_STREAM) == 0) { + send_fin_ = true; + } + return {result, eof}; + } + } + + bool Send(absl::string_view frame_header, size_t payload_length) override { + nghttp2_frame frame; + frame.hd.type = 0; + frame.hd.length = payload_length; + frame.hd.flags = 0; + frame.hd.stream_id = 0; + frame.data.padlen = 0; + const int result = send_data_( + nullptr /* session */, &frame, ToUint8Ptr(frame_header.data()), + payload_length, &provider_.source, user_data_); + QUICHE_LOG_IF(ERROR, result < 0 && result != NGHTTP2_ERR_WOULDBLOCK) + << "Unexpected error code from send: " << result; + return result == 0; + } + + bool send_fin() const override { return send_fin_; } + + private: + nghttp2_data_provider provider_; + nghttp2_send_data_callback send_data_; + void* user_data_; + bool send_fin_ = false; +}; + +std::unique_ptr MakeZeroCopyDataFrameSource( + nghttp2_data_provider provider, + void* user_data, + nghttp2_send_data_callback send_data) { + return absl::make_unique( + std::move(provider), std::move(send_data), user_data); +} + +absl::string_view ErrorString(uint32_t error_code) { + return Http2ErrorCodeToString(static_cast(error_code)); +} + +size_t PaddingLength(uint8_t flags, size_t padlen) { + return (flags & 0x8 ? 1 : 0) + padlen; +} + +struct NvFormatter { + void operator()(std::string* out, const nghttp2_nv& nv) { + absl::StrAppend(out, ToStringView(nv.name, nv.namelen), ": ", + ToStringView(nv.value, nv.valuelen)); + } +}; + +std::string NvsAsString(nghttp2_nv* nva, size_t nvlen) { + return absl::StrJoin(absl::MakeConstSpan(nva, nvlen), ", ", NvFormatter()); +} + +#define HTTP2_FRAME_SEND_LOG QUICHE_VLOG(1) + +void LogBeforeSend(const nghttp2_frame& frame) { + switch (static_cast(frame.hd.type)) { + case FrameType::DATA: + HTTP2_FRAME_SEND_LOG << "Sending DATA on stream " << frame.hd.stream_id + << " with length " + << frame.hd.length - PaddingLength(frame.hd.flags, + frame.data.padlen) + << " and padding " + << PaddingLength(frame.hd.flags, frame.data.padlen); + break; + case FrameType::HEADERS: + HTTP2_FRAME_SEND_LOG << "Sending HEADERS on stream " << frame.hd.stream_id + << " with headers [" + << NvsAsString(frame.headers.nva, + frame.headers.nvlen) + << "]"; + break; + case FrameType::PRIORITY: + HTTP2_FRAME_SEND_LOG << "Sending PRIORITY"; + break; + case FrameType::RST_STREAM: + HTTP2_FRAME_SEND_LOG << "Sending RST_STREAM on stream " + << frame.hd.stream_id << " with error code " + << ErrorString(frame.rst_stream.error_code); + break; + case FrameType::SETTINGS: + HTTP2_FRAME_SEND_LOG << "Sending SETTINGS with " << frame.settings.niv + << " entries, is_ack: " << (frame.hd.flags & 0x01); + break; + case FrameType::PUSH_PROMISE: + HTTP2_FRAME_SEND_LOG << "Sending PUSH_PROMISE"; + break; + case FrameType::PING: { + Http2PingId ping_id; + std::memcpy(&ping_id, frame.ping.opaque_data, sizeof(Http2PingId)); + HTTP2_FRAME_SEND_LOG << "Sending PING with unique_id " + << quiche::QuicheEndian::NetToHost64(ping_id) + << ", is_ack: " << (frame.hd.flags & 0x01); + break; + } + case FrameType::GOAWAY: + HTTP2_FRAME_SEND_LOG << "Sending GOAWAY with last_stream: " + << frame.goaway.last_stream_id << " and error " + << ErrorString(frame.goaway.error_code); + break; + case FrameType::WINDOW_UPDATE: + HTTP2_FRAME_SEND_LOG << "Sending WINDOW_UPDATE on stream " + << frame.hd.stream_id << " with update delta " + << frame.window_update.window_size_increment; + break; + case FrameType::CONTINUATION: + HTTP2_FRAME_SEND_LOG << "Sending CONTINUATION, which is unexpected"; + break; + } +} + +#undef HTTP2_FRAME_SEND_LOG + } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_util.h b/gquiche/http2/adapter/nghttp2_util.h index 6c2fa291..3e160b75 100644 --- a/gquiche/http2/adapter/nghttp2_util.h +++ b/gquiche/http2/adapter/nghttp2_util.h @@ -8,8 +8,10 @@ #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "gquiche/http2/adapter/data_source.h" #include "gquiche/http2/adapter/http2_protocol.h" -#include "third_party/nghttp2/src/lib/includes/nghttp2/nghttp2.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/nghttp2.h" #include "gquiche/spdy/core/spdy_header_block.h" namespace http2 { @@ -20,6 +22,18 @@ inline constexpr int kStreamCallbackFailureStatus = NGHTTP2_ERR_TEMPORAL_CALLBACK_FAILURE; inline constexpr int kCancelStatus = NGHTTP2_ERR_CANCEL; +using CallbacksDeleter = void (*)(nghttp2_session_callbacks*); +using SessionDeleter = void (*)(nghttp2_session*); + +using nghttp2_session_callbacks_unique_ptr = + std::unique_ptr; +using nghttp2_session_unique_ptr = + std::unique_ptr; + +nghttp2_session_callbacks_unique_ptr MakeCallbacksPtr( + nghttp2_session_callbacks* callbacks); +nghttp2_session_unique_ptr MakeSessionPtr(nghttp2_session* session); + uint8_t* ToUint8Ptr(char* str); uint8_t* ToUint8Ptr(const char* str); @@ -27,9 +41,9 @@ absl::string_view ToStringView(nghttp2_rcbuf* rc_buffer); absl::string_view ToStringView(uint8_t* pointer, size_t length); absl::string_view ToStringView(const uint8_t* pointer, size_t length); -// Returns the nghttp2 header structure from the given request |headers|, which +// Returns the nghttp2 header structure from the given |headers|, which // must have the correct pseudoheaders preceding other headers. -std::vector GetRequestNghttp2Nvs(absl::Span headers); +std::vector GetNghttp2Nvs(absl::Span headers); // Returns the nghttp2 header structure from the given response |headers|, with // the :status pseudoheader first based on the given |response_code|. The @@ -43,6 +57,21 @@ std::vector GetResponseNghttp2Nvs( // based on the RFC 7540 Section 7 suggestion. Http2ErrorCode ToHttp2ErrorCode(uint32_t wire_error_code); +// Converts between the integer error code used by nghttp2 and the corresponding +// InvalidFrameError value. +int ToNgHttp2ErrorCode(Http2VisitorInterface::InvalidFrameError error); +Http2VisitorInterface::InvalidFrameError ToInvalidFrameError(int error); + +// Transforms a nghttp2_data_provider into a DataFrameSource. Assumes that +// |provider| uses the zero-copy nghttp2_data_source_read_callback API. Unsafe +// otherwise. +std::unique_ptr MakeZeroCopyDataFrameSource( + nghttp2_data_provider provider, + void* user_data, + nghttp2_send_data_callback send_data); + +void LogBeforeSend(const nghttp2_frame& frame); + } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/nghttp2_util_test.cc b/gquiche/http2/adapter/nghttp2_util_test.cc new file mode 100644 index 00000000..5dad8d5c --- /dev/null +++ b/gquiche/http2/adapter/nghttp2_util_test.cc @@ -0,0 +1,109 @@ +#include "gquiche/http2/adapter/nghttp2_util.h" + +#include "gquiche/http2/adapter/nghttp2_test_utils.h" +#include "gquiche/http2/adapter/test_utils.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +// This send callback assumes |source|'s pointer is a TestDataSource, and +// |user_data| is a std::string. +int FakeSendCallback(nghttp2_session*, nghttp2_frame* /*frame*/, + const uint8_t* framehd, size_t length, + nghttp2_data_source* source, void* user_data) { + auto* dest = static_cast(user_data); + // Appends the frame header to the string. + absl::StrAppend(dest, ToStringView(framehd, 9)); + auto* test_source = static_cast(source->ptr); + absl::string_view payload = test_source->ReadNext(length); + // Appends the frame payload to the string. + absl::StrAppend(dest, payload); + return 0; +} + +TEST(MakeZeroCopyDataFrameSource, EmptyPayload) { + std::string result; + + const absl::string_view kEmptyBody = ""; + TestDataSource body1{kEmptyBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &result, FakeSendCallback); + auto [length, eof] = frame_source->SelectPayloadLength(100); + EXPECT_EQ(length, 0); + EXPECT_TRUE(eof); + frame_source->Send("ninebytes", 0); + EXPECT_EQ(result, "ninebytes"); +} + +TEST(MakeZeroCopyDataFrameSource, ShortPayload) { + std::string result; + + const absl::string_view kShortBody = + "Example Page!" + "
Wow!!" + "
" + ""; + TestDataSource body1{kShortBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &result, FakeSendCallback); + auto [length, eof] = frame_source->SelectPayloadLength(200); + EXPECT_EQ(length, kShortBody.size()); + EXPECT_TRUE(eof); + frame_source->Send("ninebytes", length); + EXPECT_EQ(result, absl::StrCat("ninebytes", kShortBody)); +} + +TEST(MakeZeroCopyDataFrameSource, MultiFramePayload) { + std::string result; + + const absl::string_view kShortBody = + "Example Page!" + "
Wow!!" + "
" + ""; + TestDataSource body1{kShortBody}; + // The TestDataSource is wrapped in the nghttp2_data_provider data type. + nghttp2_data_provider provider = body1.MakeDataProvider(); + + // This call transforms it back into a DataFrameSource, which is compatible + // with the Http2Adapter API. + std::unique_ptr frame_source = + MakeZeroCopyDataFrameSource(provider, &result, FakeSendCallback); + auto ret = frame_source->SelectPayloadLength(50); + EXPECT_EQ(ret.first, 50); + EXPECT_FALSE(ret.second); + frame_source->Send("ninebyte1", ret.first); + + ret = frame_source->SelectPayloadLength(50); + EXPECT_EQ(ret.first, 50); + EXPECT_FALSE(ret.second); + frame_source->Send("ninebyte2", ret.first); + + ret = frame_source->SelectPayloadLength(50); + EXPECT_EQ(ret.first, 44); + EXPECT_TRUE(ret.second); + frame_source->Send("ninebyte3", ret.first); + + EXPECT_EQ(result, + "ninebyte1Example Page!
Wow!!<" + "ninebyte3/th>
"); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/oghttp2_adapter.cc b/gquiche/http2/adapter/oghttp2_adapter.cc index e80d8fed..400275ef 100644 --- a/gquiche/http2/adapter/oghttp2_adapter.cc +++ b/gquiche/http2/adapter/oghttp2_adapter.cc @@ -11,11 +11,9 @@ namespace adapter { namespace { -using spdy::SpdyFrameIR; using spdy::SpdyGoAwayIR; using spdy::SpdyPingIR; using spdy::SpdyPriorityIR; -using spdy::SpdySettingsIR; using spdy::SpdyWindowUpdateIR; } // namespace @@ -30,16 +28,16 @@ std::unique_ptr OgHttp2Adapter::Create( OgHttp2Adapter::~OgHttp2Adapter() {} -ssize_t OgHttp2Adapter::ProcessBytes(absl::string_view bytes) { +bool OgHttp2Adapter::IsServerSession() const { + return session_->IsServerSession(); +} + +int64_t OgHttp2Adapter::ProcessBytes(absl::string_view bytes) { return session_->ProcessBytes(bytes); } void OgHttp2Adapter::SubmitSettings(absl::Span settings) { - auto settings_ir = absl::make_unique(); - for (const Http2Setting& setting : settings) { - settings_ir->AddSetting(setting.id, setting.value); - } - session_->EnqueueFrame(std::move(settings_ir)); + session_->SubmitSettings(settings); } void OgHttp2Adapter::SubmitPriorityForStream(Http2StreamId stream_id, @@ -54,6 +52,10 @@ void OgHttp2Adapter::SubmitPing(Http2PingId ping_id) { session_->EnqueueFrame(absl::make_unique(ping_id)); } +void OgHttp2Adapter::SubmitShutdownNotice() { + session_->StartGracefulShutdown(); +} + void OgHttp2Adapter::SubmitGoAway(Http2StreamId last_accepted_stream_id, Http2ErrorCode error_code, absl::string_view opaque_data) { @@ -67,18 +69,56 @@ void OgHttp2Adapter::SubmitWindowUpdate(Http2StreamId stream_id, absl::make_unique(stream_id, window_increment)); } -void OgHttp2Adapter::SubmitMetadata(Http2StreamId stream_id, bool fin) { - QUICHE_BUG(oghttp2_submit_metadata) << "Not implemented"; +void OgHttp2Adapter::SubmitMetadata(Http2StreamId stream_id, + size_t /* max_frame_size */, + std::unique_ptr source) { + // Not necessary to pass max_frame_size along, since OgHttp2Session tracks the + // peer's advertised max frame size. + session_->SubmitMetadata(stream_id, std::move(source)); } -std::string OgHttp2Adapter::GetBytesToWrite(absl::optional max_bytes) { - return session_->GetBytesToWrite(max_bytes); -} +int OgHttp2Adapter::Send() { return session_->Send(); } -int OgHttp2Adapter::GetPeerConnectionWindow() const { +int OgHttp2Adapter::GetSendWindowSize() const { return session_->GetRemoteWindowSize(); } +int OgHttp2Adapter::GetStreamSendWindowSize(Http2StreamId stream_id) const { + return session_->GetStreamSendWindowSize(stream_id); +} + +int OgHttp2Adapter::GetStreamReceiveWindowLimit(Http2StreamId stream_id) const { + return session_->GetStreamReceiveWindowLimit(stream_id); +} + +int OgHttp2Adapter::GetStreamReceiveWindowSize(Http2StreamId stream_id) const { + return session_->GetStreamReceiveWindowSize(stream_id); +} + +int OgHttp2Adapter::GetReceiveWindowSize() const { + return session_->GetReceiveWindowSize(); +} + +int OgHttp2Adapter::GetHpackEncoderDynamicTableSize() const { + return session_->GetHpackEncoderDynamicTableSize(); +} + +int OgHttp2Adapter::GetHpackEncoderDynamicTableCapacity() const { + return session_->GetHpackEncoderDynamicTableCapacity(); +} + +int OgHttp2Adapter::GetHpackDecoderDynamicTableSize() const { + return session_->GetHpackDecoderDynamicTableSize(); +} + +int OgHttp2Adapter::GetHpackDecoderSizeLimit() const { + return session_->GetHpackDecoderSizeLimit(); +} + +Http2StreamId OgHttp2Adapter::GetHighestReceivedStreamId() const { + return session_->GetHighestReceivedStreamId(); +} + void OgHttp2Adapter::MarkDataConsumedForStream(Http2StreamId stream_id, size_t num_bytes) { session_->Consume(stream_id, num_bytes); @@ -90,8 +130,34 @@ void OgHttp2Adapter::SubmitRst(Http2StreamId stream_id, stream_id, TranslateErrorCode(error_code))); } -const Http2Session& OgHttp2Adapter::session() const { - return *session_; +int32_t OgHttp2Adapter::SubmitRequest( + absl::Span headers, + std::unique_ptr data_source, void* user_data) { + return session_->SubmitRequest(headers, std::move(data_source), user_data); +} + +int OgHttp2Adapter::SubmitResponse( + Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) { + return session_->SubmitResponse(stream_id, headers, std::move(data_source)); +} + +int OgHttp2Adapter::SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) { + return session_->SubmitTrailer(stream_id, trailers); +} + +void OgHttp2Adapter::SetStreamUserData(Http2StreamId stream_id, + void* user_data) { + session_->SetStreamUserData(stream_id, user_data); +} + +void* OgHttp2Adapter::GetStreamUserData(Http2StreamId stream_id) { + return session_->GetStreamUserData(stream_id); +} + +bool OgHttp2Adapter::ResumeStream(Http2StreamId stream_id) { + return session_->ResumeStream(stream_id); } OgHttp2Adapter::OgHttp2Adapter(Http2VisitorInterface& visitor, Options options) diff --git a/gquiche/http2/adapter/oghttp2_adapter.h b/gquiche/http2/adapter/oghttp2_adapter.h index 51770c83..f0847491 100644 --- a/gquiche/http2/adapter/oghttp2_adapter.h +++ b/gquiche/http2/adapter/oghttp2_adapter.h @@ -1,44 +1,70 @@ #ifndef QUICHE_HTTP2_ADAPTER_OGHTTP2_ADAPTER_H_ #define QUICHE_HTTP2_ADAPTER_OGHTTP2_ADAPTER_H_ +#include #include #include "gquiche/http2/adapter/http2_adapter.h" #include "gquiche/http2/adapter/http2_session.h" #include "gquiche/http2/adapter/oghttp2_session.h" +#include "gquiche/common/platform/api/quiche_export.h" namespace http2 { namespace adapter { -class OgHttp2Adapter : public Http2Adapter { +class QUICHE_EXPORT_PRIVATE OgHttp2Adapter : public Http2Adapter { public: using Options = OgHttp2Session::Options; static std::unique_ptr Create(Http2VisitorInterface& visitor, Options options); - ~OgHttp2Adapter(); + ~OgHttp2Adapter() override; // From Http2Adapter. - ssize_t ProcessBytes(absl::string_view bytes) override; + bool IsServerSession() const override; + bool want_read() const override { return session_->want_read(); } + bool want_write() const override { return session_->want_write(); } + int64_t ProcessBytes(absl::string_view bytes) override; void SubmitSettings(absl::Span settings) override; void SubmitPriorityForStream(Http2StreamId stream_id, Http2StreamId parent_stream_id, int weight, bool exclusive) override; void SubmitPing(Http2PingId ping_id) override; + void SubmitShutdownNotice() override; void SubmitGoAway(Http2StreamId last_accepted_stream_id, Http2ErrorCode error_code, absl::string_view opaque_data) override; void SubmitWindowUpdate(Http2StreamId stream_id, int window_increment) override; - void SubmitMetadata(Http2StreamId stream_id, bool fin) override; - std::string GetBytesToWrite(absl::optional max_bytes) override; - int GetPeerConnectionWindow() const override; + void SubmitMetadata(Http2StreamId stream_id, size_t max_frame_size, + std::unique_ptr source) override; + int Send() override; + int GetSendWindowSize() const override; + int GetStreamSendWindowSize(Http2StreamId stream_id) const override; + int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const override; + int GetStreamReceiveWindowSize(Http2StreamId stream_id) const override; + int GetReceiveWindowSize() const override; + int GetHpackEncoderDynamicTableSize() const override; + int GetHpackEncoderDynamicTableCapacity() const; + int GetHpackDecoderDynamicTableSize() const override; + int GetHpackDecoderSizeLimit() const; + Http2StreamId GetHighestReceivedStreamId() const override; void MarkDataConsumedForStream(Http2StreamId stream_id, size_t num_bytes) override; void SubmitRst(Http2StreamId stream_id, Http2ErrorCode error_code) override; + int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data) override; + int SubmitResponse(Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) override; - const Http2Session& session() const; + int SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) override; + + void SetStreamUserData(Http2StreamId stream_id, void* user_data) override; + void* GetStreamUserData(Http2StreamId stream_id) override; + bool ResumeStream(Http2StreamId stream_id) override; private: OgHttp2Adapter(Http2VisitorInterface& visitor, Options options); diff --git a/gquiche/http2/adapter/oghttp2_adapter_test.cc b/gquiche/http2/adapter/oghttp2_adapter_test.cc index 6fdd585a..57e7e20b 100644 --- a/gquiche/http2/adapter/oghttp2_adapter_test.cc +++ b/gquiche/http2/adapter/oghttp2_adapter_test.cc @@ -1,6 +1,12 @@ #include "gquiche/http2/adapter/oghttp2_adapter.h" +#include + +#include "absl/strings/str_join.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" #include "gquiche/http2/adapter/mock_http2_visitor.h" +#include "gquiche/http2/adapter/oghttp2_util.h" #include "gquiche/http2/adapter/test_frame_sequence.h" #include "gquiche/http2/adapter/test_utils.h" #include "gquiche/common/platform/api/quiche_test.h" @@ -11,6 +17,25 @@ namespace adapter { namespace test { namespace { +using ConnectionError = Http2VisitorInterface::ConnectionError; + +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, + CONTINUATION, +}; + +using spdy::SpdyFrameType; + class OgHttp2AdapterTest : public testing::Test { protected: void SetUp() override { @@ -18,73 +43,3144 @@ class OgHttp2AdapterTest : public testing::Test { adapter_ = OgHttp2Adapter::Create(http2_visitor_, options); } - testing::StrictMock http2_visitor_; + DataSavingVisitor http2_visitor_; std::unique_ptr adapter_; }; +TEST_F(OgHttp2AdapterTest, IsServerSession) { + EXPECT_TRUE(adapter_->IsServerSession()); +} + TEST_F(OgHttp2AdapterTest, ProcessBytes) { + testing::InSequence seq; + EXPECT_CALL(http2_visitor_, OnFrameHeader(0, 0, 4, 0)); EXPECT_CALL(http2_visitor_, OnSettingsStart()); EXPECT_CALL(http2_visitor_, OnSettingsEnd()); + EXPECT_CALL(http2_visitor_, OnFrameHeader(0, 8, 6, 0)); EXPECT_CALL(http2_visitor_, OnPing(17, false)); adapter_->ProcessBytes( TestFrameSequence().ClientPreface().Ping(17).Serialize()); } -TEST_F(OgHttp2AdapterTest, SubmitMetadata) { - EXPECT_QUICHE_BUG(adapter_->SubmitMetadata(3, true), "Not implemented"); +TEST_F(OgHttp2AdapterTest, InitialSettings) { + DataSavingVisitor client_visitor; + OgHttp2Adapter::Options client_options{.perspective = Perspective::kClient}; + auto client_adapter = OgHttp2Adapter::Create(client_visitor, client_options); + + DataSavingVisitor server_visitor; + OgHttp2Adapter::Options server_options{.perspective = Perspective::kServer}; + auto server_adapter = OgHttp2Adapter::Create(server_visitor, server_options); + + testing::InSequence s; + + // Client sends the connection preface, including the initial SETTINGS. + EXPECT_CALL(client_visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(client_visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + { + int result = client_adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = client_visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + } + + // Server sends the connection preface, including the initial SETTINGS. + EXPECT_CALL(server_visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(server_visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + { + int result = server_adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = server_visitor.data(); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + } + + // Client processes the server's initial bytes, including initial SETTINGS. + EXPECT_CALL(client_visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(client_visitor, OnSettingsStart()); + EXPECT_CALL(client_visitor, OnSettingsEnd()); + { + const int64_t result = client_adapter->ProcessBytes(server_visitor.data()); + EXPECT_EQ(server_visitor.data().size(), static_cast(result)); + } + + // Server processes the client's initial bytes, including initial SETTINGS. + EXPECT_CALL(server_visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(server_visitor, OnSettingsStart()); + EXPECT_CALL( + server_visitor, + OnSetting(testing::AllOf( + testing::Field(&Http2Setting::id, Http2KnownSettingsId::ENABLE_PUSH), + testing::Field(&Http2Setting::value, 0)))) + .Times(2); + EXPECT_CALL(server_visitor, OnSettingsEnd()); + { + const int64_t result = server_adapter->ProcessBytes(client_visitor.data()); + EXPECT_EQ(client_visitor.data().size(), static_cast(result)); + } } -TEST_F(OgHttp2AdapterTest, GetPeerConnectionWindow) { - const int peer_window = adapter_->GetPeerConnectionWindow(); - EXPECT_GT(peer_window, 0); +TEST_F(OgHttp2AdapterTest, AutomaticSettingsAndPingAcks) { + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(http2_visitor_, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(http2_visitor_, OnSettingsStart()); + EXPECT_CALL(http2_visitor_, OnSettingsEnd()); + // PING + EXPECT_CALL(http2_visitor_, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(http2_visitor_, OnPing(42, false)); + + const int64_t read_result = adapter_->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter_->want_write()); + + // Server preface (SETTINGS) + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // PING ack + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(http2_visitor_, OnFrameSent(PING, 0, _, 0x1, 0)); + + int send_result = adapter_->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT( + http2_visitor_.data(), + EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, spdy::SpdyFrameType::PING})); } -TEST_F(OgHttp2AdapterTest, MarkDataConsumedForStream) { - EXPECT_QUICHE_BUG(adapter_->MarkDataConsumedForStream(1, 11), - "Stream 1 not found"); +TEST_F(OgHttp2AdapterTest, AutomaticPingAcksDisabled) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer, + .auto_ping_ack = false}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = + TestFrameSequence().ClientPreface().Ping(42).Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // PING + EXPECT_CALL(visitor, OnFrameHeader(0, _, PING, 0)); + EXPECT_CALL(visitor, OnPing(42, false)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_EQ(static_cast(read_result), frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + // SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + // No PING ack expected because automatic PING acks are disabled. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS})); } -TEST_F(OgHttp2AdapterTest, TestSerialize) { - EXPECT_TRUE(adapter_->session().want_read()); - EXPECT_FALSE(adapter_->session().want_write()); +TEST(OgHttp2AdapterClientTest, ClientHandles100Headers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); - adapter_->SubmitSettings( - {{HEADER_TABLE_SIZE, 128}, {MAX_FRAME_SIZE, 128 << 10}}); - EXPECT_TRUE(adapter_->session().want_write()); + testing::InSequence s; - adapter_->SubmitPriorityForStream(3, 1, 255, true); - adapter_->SubmitRst(3, Http2ErrorCode::CANCEL); - adapter_->SubmitPing(42); - adapter_->SubmitGoAway(13, Http2ErrorCode::NO_ERROR, ""); - adapter_->SubmitWindowUpdate(3, 127); - EXPECT_TRUE(adapter_->session().want_write()); + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); - EXPECT_THAT( - adapter_->GetBytesToWrite(absl::nullopt), - ContainsFrames( - {spdy::SpdyFrameType::SETTINGS, spdy::SpdyFrameType::PRIORITY, - spdy::SpdyFrameType::RST_STREAM, spdy::SpdyFrameType::PING, - spdy::SpdyFrameType::GOAWAY, spdy::SpdyFrameType::WINDOW_UPDATE})); - EXPECT_FALSE(adapter_->session().want_write()); + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, + /*fin=*/false) + .Ping(101) + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); + EXPECT_CALL(visitor, OnPing(101, false)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING})); } -TEST_F(OgHttp2AdapterTest, TestPartialSerialize) { - EXPECT_FALSE(adapter_->session().want_write()); +TEST(OgHttp2AdapterClientTest, ClientRejects100HeadersWithFin) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); - adapter_->SubmitSettings( - {{HEADER_TABLE_SIZE, 128}, {MAX_FRAME_SIZE, 128 << 10}}); - adapter_->SubmitGoAway(13, Http2ErrorCode::NO_ERROR, "And don't come back!"); - adapter_->SubmitPing(42); - EXPECT_TRUE(adapter_->session().want_write()); + testing::InSequence s; - const std::string first_part = adapter_->GetBytesToWrite(10); - EXPECT_TRUE(adapter_->session().want_write()); - const std::string second_part = adapter_->GetBytesToWrite(absl::nullopt); - EXPECT_FALSE(adapter_->session().want_write()); - EXPECT_THAT( - absl::StrCat(first_part, second_part), - ContainsFrames({spdy::SpdyFrameType::SETTINGS, - spdy::SpdyFrameType::GOAWAY, spdy::SpdyFrameType::PING})); + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, {{":status", "100"}}, /*fin=*/false) + .Headers(1, {{":status", "100"}}, /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "100")); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + // NOTE: nghttp2 does not deliver the OnEndStream event. + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, _, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterClientTest, ClientHandlesTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{"final-status", "A-OK"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, "final-status", "A-OK")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientHandlesMetadata) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientHandlesMetadataWithError) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Metadata(0, "Example connection metadata") + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Metadata(1, "Example stream metadata") + .Data(1, "This is the response body.", true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)) + .WillOnce(testing::Return(false)); + // Remaining frames are not processed due to the error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + // Negative integer returned to indicate an error. + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + EXPECT_FALSE(adapter->want_read()); + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientRstStreamWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce(testing::DoAll( + testing::InvokeWithoutArgs([&adapter]() { + adapter->SubmitRst(1, Http2ErrorCode::REFUSED_STREAM); + }), + testing::Return(Http2VisitorInterface::HEADER_RST_STREAM))); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterClientTest, ClientConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientConnectionErrorWhileHandlingHeadersOnly) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")) + .WillOnce( + testing::Return(Http2VisitorInterface::HEADER_CONNECTION_ERROR)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientRejectsHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)) + .WillOnce(testing::Return(false)); + // Rejecting headers leads to a connection error. + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientHandlesSmallerHpackHeaderTableSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"x-i-do-not-like", "green eggs and ham"}, + {"x-i-will-not-eat-them", "here or there, in a box, with a fox"}, + {"x-like-them-in-a-house", "no"}, + {"x-like-them-with-a-mouse", "no"}, + }); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_GT(adapter->GetHpackEncoderDynamicTableSize(), 100); + + const std::string stream_frames = + TestFrameSequence().Settings({{HEADER_TABLE_SIZE, 100u}}).Serialize(); + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 100u})); + // Duplicate setting callback due to the way extensions work. + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 100u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 100); + EXPECT_LE(adapter->GetHpackEncoderDynamicTableSize(), 100); +} + +TEST(OgHttp2AdapterClientTest, ClientHandlesLargerHpackHeaderTableSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 4096); + + const std::string stream_frames = + TestFrameSequence().Settings({{HEADER_TABLE_SIZE, 40960u}}).Serialize(); + // Server preface (SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 40960u})); + // Duplicate setting callback due to the way extensions work. + EXPECT_CALL(visitor, OnSetting(Http2Setting{HEADER_TABLE_SIZE, 40960u})); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + // The increased capacity will not be applied until a SETTINGS ack is + // serialized. + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 4096); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + EXPECT_EQ(adapter->GetHpackEncoderDynamicTableCapacity(), 40960); +} + +TEST(OgHttp2AdapterClientTest, ClientSendsHpackHeaderTableSetting) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + }); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers( + 1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}, + {"x-i-do-not-like", "green eggs and ham"}, + {"x-i-will-not-eat-them", "here or there, in a box, with a fox"}, + {"x-like-them-in-a-house", "no"}, + {"x-like-them-with-a-mouse", "no"}}, + /*fin=*/true) + .Serialize(); + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Server acks client's initial SETTINGS. + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 1)); + EXPECT_CALL(visitor, OnSettingsAck()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(7); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + // Submit settings, check decoder table size. + adapter->SubmitSettings({{HEADER_TABLE_SIZE, 100u}}); + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + // Server preface SETTINGS ack + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + // SETTINGS with the new header table size value + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + + // Because the client has not yet seen an ack from the server for the SETTINGS + // with header table size, it has not applied the new value. + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::vector headers2 = ToHeaders({ + {":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}, + }); + + const int32_t stream_id2 = adapter->SubmitRequest(headers2, nullptr, nullptr); + ASSERT_GT(stream_id2, stream_id1); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id2, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id2, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + visitor.Clear(); + + const std::string response_frames = + TestFrameSequence() + .Headers(stream_id2, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id2, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id2)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id2, _, _)).Times(3); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id2)); + EXPECT_CALL(visitor, OnEndStream(stream_id2)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id2, Http2ErrorCode::HTTP2_NO_ERROR)); + + const int64_t response_result = adapter->ProcessBytes(response_frames); + EXPECT_EQ(response_frames.size(), static_cast(response_result)); + + // Still no ack for the outbound settings. + EXPECT_GT(adapter->GetHpackDecoderSizeLimit(), 100); + + const std::string settings_ack = + TestFrameSequence().SettingsAck().Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 1)); + EXPECT_CALL(visitor, OnSettingsAck()); + + const int64_t ack_result = adapter->ProcessBytes(settings_ack); + EXPECT_EQ(settings_ack.size(), static_cast(ack_result)); + // Ack has finally arrived. + EXPECT_EQ(adapter->GetHpackDecoderSizeLimit(), 100); +} + +// TODO(birenroy): Validate headers and re-enable this test. The library should +// invoke OnErrorDebug() with an error message for the invalid header. The +// library should also invoke OnInvalidFrame() for the invalid HEADERS frame. +TEST(OgHttp2AdapterClientTest, DISABLED_ClientHandlesInvalidTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(1, "This is the response body.") + .Headers(1, {{":bad-status", "9000"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + + // Bad status trailer will cause a PROTOCOL_ERROR. The header is never + // delivered in an OnHeaderForStream callback. + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id1, 4, 0x0, 1)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::PROTOCOL_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterClientTest, ClientFailsOnGoAway) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + + const char* kSentinel1 = "arbitrary pointer 1"; + const int32_t stream_id1 = + adapter->SubmitRequest(headers1, nullptr, const_cast(kSentinel1)); + ASSERT_GT(stream_id1, 0); + QUICHE_LOG(INFO) << "Created stream: " << stream_id1; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .GoAway(1, Http2ErrorCode::INTERNAL_ERROR, "indigestion") + .Data(1, "This is the response body.") + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, + OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, GOAWAY, 0)); + // TODO(birenroy): Pass the GOAWAY opaque data through the oghttp2 stack. + EXPECT_CALL(visitor, OnGoAway(1, Http2ErrorCode::INTERNAL_ERROR, "")) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientRejects101Response) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + const std::vector headers1 = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"upgrade", "new-protocol"}}); + + const int32_t stream_id1 = adapter->SubmitRequest(headers1, nullptr, nullptr); + ASSERT_GT(stream_id1, 0); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id1, _, 0x5, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string stream_frames = + TestFrameSequence() + .ServerPreface() + .Headers(1, + {{":status", "101"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 1, Http2VisitorInterface::InvalidFrameError::kHttpMessaging)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_frames.size()), stream_result); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL( + visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + EXPECT_TRUE(adapter->want_write()); + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterClientTest, ClientObeysMaxConcurrentStreams) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_FALSE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + // Even though the user has not queued any frames for the session, it should + // still send the connection preface. + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + // Initial SETTINGS. + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence() + .ServerPreface({{.id = MAX_CONCURRENT_STREAMS, .value = 1}}) + .Serialize(); + testing::InSequence s; + + // Server preface (SETTINGS with MAX_CONCURRENT_STREAMS) + EXPECT_CALL(visitor, OnFrameHeader(0, 6, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSetting); + // TODO(diannahu): Remove this duplicate call with a separate + // ExtensionVisitorInterface implementation. + EXPECT_CALL(visitor, OnSetting); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string kBody = "This is an example request body."; + auto body1 = absl::make_unique(visitor, true); + body1->AppendPayload(kBody); + body1->EndData(); + const int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + const int next_stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + + // A new pending stream is created, but because of MAX_CONCURRENT_STREAMS, the + // session should not want to write it at the moment. + EXPECT_GT(next_stream_id, stream_id); + EXPECT_FALSE(adapter->want_write()); + + const std::string stream_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); + EXPECT_CALL(visitor, + OnHeaderForStream(stream_id, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, "date", + "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 26, DATA, 0x1)); + EXPECT_CALL(visitor, OnBeginDataForStream(stream_id, 26)); + EXPECT_CALL(visitor, + OnDataForStream(stream_id, "This is the response body.")); + EXPECT_CALL(visitor, OnEndStream(stream_id)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + // The first stream should close, which should make the session want to write + // the next stream. + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, next_stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, next_stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterClientTest, FailureSendingConnectionPreface) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + visitor.set_has_write_error(); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kSendError)); + + int result = adapter->Send(); + EXPECT_EQ(result, Http2VisitorInterface::kSendError); +} + +TEST(OgHttp2AdapterClientTest, ClientForbidsPushPromise) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::vector push_headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/push"}}); + const std::string frames = TestFrameSequence() + .ServerPreface() + .SettingsAck() + .PushPromise(stream_id, 2, push_headers) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0, though this is not explicitly + // required for OgHttp2: should it be?) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The PUSH_PROMISE is treated as an invalid frame. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, PUSH_PROMISE, _)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidPushPromise)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_LT(read_result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // SETTINGS ack. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientForbidsPushStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + testing::InSequence s; + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + + visitor.Clear(); + + const std::vector headers = + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}); + const int32_t stream_id = adapter->SubmitRequest(headers, nullptr, nullptr); + ASSERT_GT(stream_id, 0); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + write_result = adapter->Send(); + EXPECT_EQ(0, write_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + const std::string frames = + TestFrameSequence() + .ServerPreface() + .SettingsAck() + .Headers(2, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/true) + .Serialize(); + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + // SETTINGS ack (to acknowledge PUSH_ENABLED=0, though this is not explicitly + // required for OgHttp2: should it be?) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck); + + // The push HEADERS are invalid. + EXPECT_CALL(visitor, OnFrameHeader(2, _, HEADERS, _)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidNewStreamId)); + + const int64_t read_result = adapter->ProcessBytes(frames); + EXPECT_LT(read_result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int result = adapter->Send(); + EXPECT_EQ(0, result); + // SETTINGS ack. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2AdapterClientTest, ClientReceivesDataOnClosedStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kClient}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + int result = adapter->Send(); + EXPECT_EQ(0, result); + absl::string_view data = visitor.data(); + EXPECT_THAT(data, testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + data.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(data, EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Client SETTINGS ack + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client open a stream with a request. + int stream_id = + adapter->SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + + // Let the client RST_STREAM the stream it opened. + adapter->SubmitRst(stream_id, Http2ErrorCode::CANCEL); + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, stream_id, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, stream_id, _, 0x0, + static_cast(Http2ErrorCode::CANCEL))); + EXPECT_CALL(visitor, + OnCloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR)); + + result = adapter->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM})); + visitor.Clear(); + + // Let the server send a response on the stream. (It might not have received + // the RST_STREAM yet.) + const std::string response_frames = + TestFrameSequence() + .Headers(stream_id, + {{":status", "200"}, + {"server", "my-fake-server"}, + {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, + /*fin=*/false) + .Data(stream_id, "This is the response body.", /*fin=*/true) + .Serialize(); + + // The visitor gets notified about the HEADERS frame and DATA frame for the + // closed stream with no further processing on either frame. + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 0x4)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, DATA, 0x1)); + + const int64_t response_result = adapter->ProcessBytes(response_frames); + EXPECT_EQ(response_frames.size(), static_cast(response_result)); + + EXPECT_FALSE(adapter->want_write()); +} + +TEST_F(OgHttp2AdapterTest, SubmitMetadata) { + auto source = absl::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter_->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter_->want_write()); + + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(http2_visitor_, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter_->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + http2_visitor_.data(), + EqualsFrames({spdy::SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter_->want_write()); +} + +TEST_F(OgHttp2AdapterTest, SubmitMetadataMultipleFrames) { + const auto kLargeValue = std::string(63 * 1024, 'a'); + auto source = absl::make_unique( + ToHeaderBlock(ToHeaders({{"large-value", kLargeValue}}))); + adapter_->SubmitMetadata(1, 16384u, std::move(source)); + EXPECT_TRUE(adapter_->want_write()); + + testing::InSequence seq; + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(kMetadataFrameType, 1, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(kMetadataFrameType, 1, _, 0x4)); + EXPECT_CALL(http2_visitor_, OnFrameSent(kMetadataFrameType, 1, _, 0x4, 0)); + + int result = adapter_->Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = http2_visitor_.data(); + EXPECT_THAT( + serialized, + EqualsFrames({spdy::SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType), + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter_->want_write()); +} + +TEST_F(OgHttp2AdapterTest, SubmitConnectionMetadata) { + auto source = absl::make_unique(ToHeaderBlock(ToHeaders( + {{"query-cost", "is too darn high"}, {"secret-sauce", "hollandaise"}}))); + adapter_->SubmitMetadata(0, 16384u, std::move(source)); + EXPECT_TRUE(adapter_->want_write()); + + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(kMetadataFrameType, 0, _, 0x4)); + EXPECT_CALL(http2_visitor_, OnFrameSent(kMetadataFrameType, 0, _, 0x4, 0)); + + int result = adapter_->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + http2_visitor_.data(), + EqualsFrames({spdy::SpdyFrameType::SETTINGS, + static_cast(kMetadataFrameType)})); + EXPECT_FALSE(adapter_->want_write()); +} + +TEST_F(OgHttp2AdapterTest, GetSendWindowSize) { + const int peer_window = adapter_->GetSendWindowSize(); + EXPECT_EQ(peer_window, kInitialFlowControlWindowSize); +} + +TEST_F(OgHttp2AdapterTest, MarkDataConsumedForStream) { + EXPECT_QUICHE_BUG(adapter_->MarkDataConsumedForStream(1, 11), + "Stream 1 not found"); +} + +TEST_F(OgHttp2AdapterTest, TestSerialize) { + EXPECT_TRUE(adapter_->want_read()); + EXPECT_FALSE(adapter_->want_write()); + + adapter_->SubmitSettings( + {{HEADER_TABLE_SIZE, 128}, {MAX_FRAME_SIZE, 128 << 10}}); + EXPECT_TRUE(adapter_->want_write()); + + adapter_->SubmitPriorityForStream(3, 1, 255, true); + adapter_->SubmitRst(3, Http2ErrorCode::CANCEL); + adapter_->SubmitPing(42); + adapter_->SubmitGoAway(13, Http2ErrorCode::HTTP2_NO_ERROR, ""); + adapter_->SubmitWindowUpdate(3, 127); + EXPECT_TRUE(adapter_->want_write()); + + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(PRIORITY, 3, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(PRIORITY, 3, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(RST_STREAM, 3, _, 0x0, 0x8)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(PING, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(PING, 0, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(WINDOW_UPDATE, 3, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(WINDOW_UPDATE, 3, _, 0x0, 0)); + + int result = adapter_->Send(); + EXPECT_EQ(0, result); + EXPECT_THAT( + http2_visitor_.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PRIORITY, + SpdyFrameType::RST_STREAM, SpdyFrameType::PING, + SpdyFrameType::GOAWAY, SpdyFrameType::WINDOW_UPDATE})); + EXPECT_FALSE(adapter_->want_write()); +} + +TEST_F(OgHttp2AdapterTest, TestPartialSerialize) { + EXPECT_FALSE(adapter_->want_write()); + + adapter_->SubmitSettings( + {{HEADER_TABLE_SIZE, 128}, {MAX_FRAME_SIZE, 128 << 10}}); + adapter_->SubmitGoAway(13, Http2ErrorCode::HTTP2_NO_ERROR, + "And don't come back!"); + adapter_->SubmitPing(42); + EXPECT_TRUE(adapter_->want_write()); + + http2_visitor_.set_send_limit(20); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + int result = adapter_->Send(); + EXPECT_EQ(0, result); + EXPECT_TRUE(adapter_->want_write()); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + result = adapter_->Send(); + EXPECT_EQ(0, result); + EXPECT_TRUE(adapter_->want_write()); + EXPECT_CALL(http2_visitor_, OnBeforeFrameSent(PING, 0, _, 0x0)); + EXPECT_CALL(http2_visitor_, OnFrameSent(PING, 0, _, 0x0, 0)); + result = adapter_->Send(); + EXPECT_EQ(0, result); + EXPECT_FALSE(adapter_->want_write()); + EXPECT_THAT(http2_visitor_.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY, + SpdyFrameType::PING})); +} + +TEST(OgHttp2AdapterServerTest, ClientSendsContinuation) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true, + /*add_continuation=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 1)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); +} + +TEST(OgHttp2AdapterServerTest, ClientSendsMetadataWithContinuation) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Metadata(0, "Example connection metadata in multiple frames", true) + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false, + /*add_continuation=*/true) + .Metadata(1, + "Some stream metadata that's also sent in multiple frames", + true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Metadata on stream 0 + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnFrameHeader(0, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataForStream(0, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(0)); + + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, CONTINUATION, 4)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + // Metadata on stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 0)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, kMetadataFrameType, 4)); + EXPECT_CALL(visitor, OnBeginMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataForStream(1, _)); + EXPECT_CALL(visitor, OnMetadataEndForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + EXPECT_EQ(TestFrameSequence::MetadataBlockForPayload( + "Example connection metadata in multiple frames"), + absl::StrJoin(visitor.GetMetadata(0), "")); + EXPECT_EQ(TestFrameSequence::MetadataBlockForPayload( + "Some stream metadata that's also sent in multiple frames"), + absl::StrJoin(visitor.GetMetadata(1), "")); +} + +TEST(OgHttp2AdapterServerTest, ServerSubmitsResponseWithDataSourceError) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + auto body1 = absl::make_unique(visitor, false); + body1->SimulateError(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + // TODO(birenroy): Send RST_STREAM INTERNAL_ERROR to the client as well. + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::INTERNAL_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // Since the stream has been closed, it is not possible to submit trailers for + // the stream. + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_LT(trailer_result, 0); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterServerTest, CompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the response body.", /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 1)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, _)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterServerTest, IncompleteRequestWithServerResponse) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + // RST_STREAM NO_ERROR option is disabled. + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterServerTest, + IncompleteRequestWithServerResponseRstStreamEnabled) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer, + .rst_stream_no_error_when_incomplete = true}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + int submit_result = + adapter->SubmitResponse(1, ToHeaders({{":status", "200"}}), nullptr); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(RST_STREAM, 1, 4, 0x0, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::RST_STREAM})); + EXPECT_FALSE(adapter->want_write()); +} + +TEST(OgHttp2AdapterServerTest, ServerSendsInvalidTrailers) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + EXPECT_FALSE(adapter->want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload(kBody); + body1->EndData(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames( + {spdy::SpdyFrameType::SETTINGS, spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::HEADERS, spdy::SpdyFrameType::DATA})); + EXPECT_THAT(visitor.data(), testing::HasSubstr(kBody)); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{":final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); +} + +// Tests the case where the response body is in the progress of being sent while +// trailers are queued. +TEST(OgHttp2AdapterServerTest, ServerSubmitsTrailersWhileDataDeferred) { + DataSavingVisitor visitor; + for (const bool queue_trailers : {true, false}) { + OgHttp2Adapter::Options options{ + .perspective = Perspective::kServer, + .trailers_require_end_data = queue_trailers}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + + const absl::string_view kBody = "This is an example response body."; + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload(kBody); + auto* body1_ptr = body1.get(); + int submit_result = adapter->SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + visitor.Clear(); + EXPECT_FALSE(adapter->want_write()); + + int trailer_result = + adapter->SubmitTrailer(1, ToHeaders({{"final-status", "a-ok"}})); + ASSERT_EQ(trailer_result, 0); + if (queue_trailers) { + // Even though there are new trailers to write, the data source has not + // finished writing data and is blocked. + EXPECT_FALSE(adapter->want_write()); + + body1_ptr->EndData(); + adapter->ResumeStream(1); + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + } else { + // Even though the data source has not finished sending data, the library + // will write the trailers anyway. + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + } + } +} + +TEST(OgHttp2AdapterServerTest, ServerErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"accept", "some bogus value!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnHeaderForStream(1, "accept", "some bogus value!")) + .WillOnce(testing::Return(Http2VisitorInterface::HEADER_RST_STREAM)); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, _)); + EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnWindowUpdate(0, 2000)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM})); +} + +TEST(OgHttp2AdapterServerTest, ServerConnectionErrorWhileHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}, + {"Accept", "uppercase, oh boy!"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL( + visitor, + OnInvalidFrame(1, Http2VisitorInterface::InvalidFrameError::kHttpHeader)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kHeaderError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 1, 4, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 1, 4, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS ack + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerErrorAfterHandlingHeaders) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +// Exercises the case when a visitor chooses to reject a frame based solely on +// the frame header, which is a fatal error for the connection. +TEST(OgHttp2AdapterServerTest, ServerRejectsFrameHeader) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Ping(64) + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .WindowUpdate(1, 2000) + .Data(1, "This is the request body.") + .WindowUpdate(0, 2000) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerRejectsBeginningOfData) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)) + .WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerRejectsStreamData) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(1, "This is the request body.") + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .RstStream(3, Http2ErrorCode::CANCEL) + .Ping(47) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); + EXPECT_CALL(visitor, OnDataForStream(1, _)).WillOnce(testing::Return(false)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kParseError)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::INTERNAL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +// Exercises a naive mutually recursive test client and server. This test fails +// without recursion guards in OgHttp2Session. +TEST(OgHttp2AdapterInteractionTest, ClientServerInteractionTest) { + MockHttp2Visitor client_visitor; + auto client_adapter = + OgHttp2Adapter::Create(client_visitor, {Perspective::kClient}); + MockHttp2Visitor server_visitor; + auto server_adapter = + OgHttp2Adapter::Create(server_visitor, {Perspective::kServer}); + + // Feeds bytes sent from the client into the server's ProcessBytes. + EXPECT_CALL(client_visitor, OnReadyToSend(_)) + .WillRepeatedly( + testing::Invoke(server_adapter.get(), &OgHttp2Adapter::ProcessBytes)); + // Feeds bytes sent from the server into the client's ProcessBytes. + EXPECT_CALL(server_visitor, OnReadyToSend(_)) + .WillRepeatedly( + testing::Invoke(client_adapter.get(), &OgHttp2Adapter::ProcessBytes)); + // Sets up the server to respond automatically to a request from a client. + EXPECT_CALL(server_visitor, OnEndHeadersForStream(_)) + .WillRepeatedly([&server_adapter](Http2StreamId stream_id) { + server_adapter->SubmitResponse( + stream_id, ToHeaders({{":status", "200"}}), nullptr); + server_adapter->Send(); + return true; + }); + // Sets up the client to create a new stream automatically when receiving a + // response. + EXPECT_CALL(client_visitor, OnEndHeadersForStream(_)) + .WillRepeatedly([&client_adapter, + &client_visitor](Http2StreamId stream_id) { + if (stream_id < 10) { + const Http2StreamId new_stream_id = stream_id + 2; + auto body = + absl::make_unique(client_visitor, true); + body->AppendPayload("This is an example request body."); + body->EndData(); + const int created_stream_id = client_adapter->SubmitRequest( + ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", + absl::StrCat("/this/is/request/", new_stream_id)}}), + std::move(body), nullptr); + EXPECT_EQ(new_stream_id, created_stream_id); + client_adapter->Send(); + } + return true; + }); + + // Submit a request to ensure the first stream is created. + int stream_id = client_adapter->SubmitRequest( + ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + nullptr, nullptr); + EXPECT_EQ(stream_id, 1); + + client_adapter->Send(); +} + +TEST(OgHttp2AdapterServerTest, ServerForbidsNewStreamBelowWatermark) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(3, + {{":method", "POST"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/false) + .Data(3, "This is the request body.") + .Headers(1, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "POST")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 25, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(3, 25)); + EXPECT_CALL(visitor, OnDataForStream(3, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kInvalidNewStreamId)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_EQ(3, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerForbidsWindowUpdateOnIdleStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence().ClientPreface().WindowUpdate(1, 42).Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerForbidsDataOnIdleStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Data(1, "Sorry, out of order") + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, DATA, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerForbidsRstStreamOnIdleStream) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + + EXPECT_EQ(0, adapter->GetHighestReceivedStreamId()); + + const std::string frames = + TestFrameSequence() + .ClientPreface() + .RstStream(1, Http2ErrorCode::ENHANCE_YOUR_CALM) + .Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnConnectionError(ConnectionError::kWrongFrameSequence)); + + const int64_t result = adapter->ProcessBytes(frames); + EXPECT_LT(result, 0); + + EXPECT_EQ(1, adapter->GetHighestReceivedStreamId()); + + EXPECT_TRUE(adapter->want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + int send_result = adapter->Send(); + // Some bytes should have been serialized. + EXPECT_EQ(0, send_result); + // SETTINGS, SETTINGS ack, and GOAWAY. + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, ServerForbidsNewStreamAboveStreamLimit) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client send a SETTINGS ack and then attempt to open more than the + // advertised number of streams. The overflow stream should be rejected. + const std::string stream_frames = + TestFrameSequence() + .SettingsAck() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0x1)); + EXPECT_CALL(visitor, OnSettingsAck()); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL( + visitor, + OnInvalidFrame(3, Http2VisitorInterface::InvalidFrameError::kProtocol)); + // The oghttp2 stack also signals the connection error via OnConnectionError() + // and a negative ProcessBytes() return value. + EXPECT_CALL(visitor, + OnConnectionError(Http2VisitorInterface::ConnectionError:: + kExceededMaxConcurrentStreams)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_LT(stream_result, 0); + + // The server should send a GOAWAY for this error, even though + // OnInvalidFrame() returns true. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(GOAWAY, 0, _, 0x0, + static_cast(Http2ErrorCode::PROTOCOL_ERROR))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2AdapterServerTest, + ServerRstStreamsNewStreamAboveStreamLimitBeforeAck) { + DataSavingVisitor visitor; + OgHttp2Adapter::Options options{.perspective = Perspective::kServer}; + auto adapter = OgHttp2Adapter::Create(visitor, options); + adapter->SubmitSettings({{MAX_CONCURRENT_STREAMS, 1}}); + + const std::string initial_frames = + TestFrameSequence().ClientPreface().Serialize(); + + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = adapter->ProcessBytes(initial_frames); + EXPECT_EQ(static_cast(initial_result), initial_frames.size()); + + EXPECT_TRUE(adapter->want_write()); + + // Server initial SETTINGS (with MAX_CONCURRENT_STREAMS) and SETTINGS ack. + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 6, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 6, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, 0, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, 0, 0x1, 0)); + + int send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::SETTINGS})); + visitor.Clear(); + + // Let the client avoid sending a SETTINGS ack and attempt to open more than + // the advertised number of streams. The server should still reject the + // overflow stream, albeit with RST_STREAM REFUSED_STREAM instead of GOAWAY. + const std::string stream_frames = + TestFrameSequence() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Headers(3, + {{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}, + /*fin=*/true) + .Serialize(); + + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, _, _)).Times(4); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 0x5)); + EXPECT_CALL(visitor, + OnInvalidFrame( + 3, Http2VisitorInterface::InvalidFrameError::kRefusedStream)); + + const int64_t stream_result = adapter->ProcessBytes(stream_frames); + EXPECT_EQ(static_cast(stream_result), stream_frames.size()); + + // The server sends a RST_STREAM for the offending stream. + EXPECT_TRUE(adapter->want_write()); + EXPECT_CALL(visitor, OnBeforeFrameSent(RST_STREAM, 3, _, 0x0)); + EXPECT_CALL(visitor, + OnFrameSent(RST_STREAM, 3, _, 0x0, + static_cast(Http2ErrorCode::REFUSED_STREAM))); + + send_result = adapter->Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::RST_STREAM})); } } // namespace diff --git a/gquiche/http2/adapter/oghttp2_session.cc b/gquiche/http2/adapter/oghttp2_session.cc index db82b9d6..b723ae91 100644 --- a/gquiche/http2/adapter/oghttp2_session.cc +++ b/gquiche/http2/adapter/oghttp2_session.cc @@ -1,27 +1,284 @@ #include "gquiche/http2/adapter/oghttp2_session.h" +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/escaping.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/http2_util.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/adapter/oghttp2_util.h" +#include "gquiche/spdy/core/spdy_protocol.h" + namespace http2 { namespace adapter { +namespace { + +using ConnectionError = Http2VisitorInterface::ConnectionError; +using SpdyFramerError = Http2DecoderAdapter::SpdyFramerError; + +using ::spdy::SpdySettingsIR; + +// #define OGHTTP2_DEBUG_TRACE 1 + +#ifdef OGHTTP2_DEBUG_TRACE +const bool kTraceLoggingEnabled = true; +#else +const bool kTraceLoggingEnabled = false; +#endif + +const uint32_t kMaxAllowedMetadataFrameSize = 65536u; +const uint32_t kDefaultHpackTableCapacity = 4096u; +const uint32_t kMaximumHpackTableCapacity = 65536u; + +// TODO(birenroy): Consider incorporating spdy::FlagsSerializionVisitor here. +class FrameAttributeCollector : public spdy::SpdyFrameVisitor { + public: + FrameAttributeCollector() = default; + void VisitData(const spdy::SpdyDataIR& data) override { + frame_type_ = static_cast(data.frame_type()); + stream_id_ = data.stream_id(); + flags_ = (data.fin() ? 0x1 : 0) | (data.padded() ? 0x8 : 0); + } + void VisitHeaders(const spdy::SpdyHeadersIR& headers) override { + frame_type_ = static_cast(headers.frame_type()); + stream_id_ = headers.stream_id(); + flags_ = 0x4 | (headers.fin() ? 0x1 : 0) | (headers.padded() ? 0x8 : 0) | + (headers.has_priority() ? 0x20 : 0); + } + void VisitPriority(const spdy::SpdyPriorityIR& priority) override { + frame_type_ = static_cast(priority.frame_type()); + frame_type_ = 2; + stream_id_ = priority.stream_id(); + } + void VisitRstStream(const spdy::SpdyRstStreamIR& rst_stream) override { + frame_type_ = static_cast(rst_stream.frame_type()); + frame_type_ = 3; + stream_id_ = rst_stream.stream_id(); + error_code_ = rst_stream.error_code(); + } + void VisitSettings(const spdy::SpdySettingsIR& settings) override { + frame_type_ = static_cast(settings.frame_type()); + frame_type_ = 4; + flags_ = (settings.is_ack() ? 0x1 : 0); + } + void VisitPushPromise(const spdy::SpdyPushPromiseIR& push_promise) override { + frame_type_ = static_cast(push_promise.frame_type()); + frame_type_ = 5; + stream_id_ = push_promise.stream_id(); + flags_ = (push_promise.padded() ? 0x8 : 0); + } + void VisitPing(const spdy::SpdyPingIR& ping) override { + frame_type_ = static_cast(ping.frame_type()); + frame_type_ = 6; + flags_ = (ping.is_ack() ? 0x1 : 0); + } + void VisitGoAway(const spdy::SpdyGoAwayIR& goaway) override { + frame_type_ = static_cast(goaway.frame_type()); + frame_type_ = 7; + error_code_ = goaway.error_code(); + } + void VisitWindowUpdate( + const spdy::SpdyWindowUpdateIR& window_update) override { + frame_type_ = static_cast(window_update.frame_type()); + frame_type_ = 8; + stream_id_ = window_update.stream_id(); + } + void VisitContinuation( + const spdy::SpdyContinuationIR& continuation) override { + frame_type_ = static_cast(continuation.frame_type()); + stream_id_ = continuation.stream_id(); + flags_ = continuation.end_headers() ? 0x4 : 0; + } + void VisitUnknown(const spdy::SpdyUnknownIR& unknown) override { + frame_type_ = static_cast(unknown.frame_type()); + stream_id_ = unknown.stream_id(); + flags_ = unknown.flags(); + } + void VisitAltSvc(const spdy::SpdyAltSvcIR& /*altsvc*/) override {} + void VisitPriorityUpdate( + const spdy::SpdyPriorityUpdateIR& /*priority_update*/) override {} + void VisitAcceptCh(const spdy::SpdyAcceptChIR& /*accept_ch*/) override {} + + uint32_t stream_id() { return stream_id_; } + uint32_t error_code() { return error_code_; } + uint8_t frame_type() { return frame_type_; } + uint8_t flags() { return flags_; } + + private: + uint32_t stream_id_ = 0; + uint32_t error_code_ = 0; + uint8_t frame_type_ = 0; + uint8_t flags_ = 0; +}; + +absl::string_view TracePerspectiveAsString(Perspective p) { + switch (p) { + case Perspective::kClient: + return "OGHTTP2_CLIENT"; + case Perspective::kServer: + return "OGHTTP2_SERVER"; + } + return "OGHTTP2_SERVER"; +} + +class RunOnExit { + public: + RunOnExit() = default; + explicit RunOnExit(std::function f) : f_(std::move(f)) {} + + RunOnExit(const RunOnExit& other) = delete; + RunOnExit& operator=(const RunOnExit& other) = delete; + RunOnExit(RunOnExit&& other) = delete; + RunOnExit& operator=(RunOnExit&& other) = delete; + + ~RunOnExit() { + if (f_) { + f_(); + } + f_ = {}; + } + + void emplace(std::function f) { f_ = std::move(f); } + + private: + std::function f_; +}; + +Http2ErrorCode GetHttp2ErrorCode(SpdyFramerError error) { + switch (error) { + case SpdyFramerError::SPDY_NO_ERROR: + return Http2ErrorCode::HTTP2_NO_ERROR; + case SpdyFramerError::SPDY_INVALID_STREAM_ID: + case SpdyFramerError::SPDY_INVALID_CONTROL_FRAME: + case SpdyFramerError::SPDY_INVALID_PADDING: + case SpdyFramerError::SPDY_INVALID_DATA_FRAME_FLAGS: + case SpdyFramerError::SPDY_UNEXPECTED_FRAME: + return Http2ErrorCode::PROTOCOL_ERROR; + case SpdyFramerError::SPDY_CONTROL_PAYLOAD_TOO_LARGE: + case SpdyFramerError::SPDY_INVALID_CONTROL_FRAME_SIZE: + case SpdyFramerError::SPDY_OVERSIZED_PAYLOAD: + return Http2ErrorCode::FRAME_SIZE_ERROR; + case SpdyFramerError::SPDY_DECOMPRESS_FAILURE: + case SpdyFramerError::SPDY_HPACK_INDEX_VARINT_ERROR: + case SpdyFramerError::SPDY_HPACK_NAME_LENGTH_VARINT_ERROR: + case SpdyFramerError::SPDY_HPACK_VALUE_LENGTH_VARINT_ERROR: + case SpdyFramerError::SPDY_HPACK_NAME_TOO_LONG: + case SpdyFramerError::SPDY_HPACK_VALUE_TOO_LONG: + case SpdyFramerError::SPDY_HPACK_NAME_HUFFMAN_ERROR: + case SpdyFramerError::SPDY_HPACK_VALUE_HUFFMAN_ERROR: + case SpdyFramerError::SPDY_HPACK_MISSING_DYNAMIC_TABLE_SIZE_UPDATE: + case SpdyFramerError::SPDY_HPACK_INVALID_INDEX: + case SpdyFramerError::SPDY_HPACK_INVALID_NAME_INDEX: + case SpdyFramerError::SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_NOT_ALLOWED: + case SpdyFramerError:: + SPDY_HPACK_INITIAL_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_LOW_WATER_MARK: + case SpdyFramerError:: + SPDY_HPACK_DYNAMIC_TABLE_SIZE_UPDATE_IS_ABOVE_ACKNOWLEDGED_SETTING: + case SpdyFramerError::SPDY_HPACK_TRUNCATED_BLOCK: + case SpdyFramerError::SPDY_HPACK_FRAGMENT_TOO_LONG: + case SpdyFramerError::SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT: + return Http2ErrorCode::COMPRESSION_ERROR; + case SpdyFramerError::SPDY_INTERNAL_FRAMER_ERROR: + case SpdyFramerError::SPDY_STOP_PROCESSING: + case SpdyFramerError::LAST_ERROR: + return Http2ErrorCode::INTERNAL_ERROR; + } + return Http2ErrorCode::INTERNAL_ERROR; +} + +bool IsResponse(HeaderType type) { + return type == HeaderType::RESPONSE_100 || type == HeaderType::RESPONSE; +} + +bool StatusIs1xx(absl::string_view status) { + return status.size() == 3 && status[0] == '1'; +} + +// Returns the upper bound on HPACK encoder table capacity. If not specified in +// the Options, a reasonable default upper bound is used. +uint32_t HpackCapacityBound(const OgHttp2Session::Options& o) { + return o.max_hpack_encoding_table_capacity.value_or( + kMaximumHpackTableCapacity); +} + +} // namespace + void OgHttp2Session::PassthroughHeadersHandler::OnHeaderBlockStart() { - visitor_.OnBeginHeadersForStream(stream_id_); + result_ = Http2VisitorInterface::HEADER_OK; + const bool status = visitor_.OnBeginHeadersForStream(stream_id_); + if (!status) { + result_ = Http2VisitorInterface::HEADER_CONNECTION_ERROR; + } + validator_.StartHeaderBlock(); } void OgHttp2Session::PassthroughHeadersHandler::OnHeader( absl::string_view key, absl::string_view value) { - visitor_.OnHeaderForStream(stream_id_, key, value); + if (result_ != Http2VisitorInterface::HEADER_OK) { + QUICHE_VLOG(2) << "Early return; status not HEADER_OK"; + return; + } + const auto validation_result = validator_.ValidateSingleHeader(key, value); + if (validation_result == HeaderValidator::HEADER_VALUE_INVALID_STATUS) { + QUICHE_VLOG(2) << "RST_STREAM: invalid status found"; + result_ = Http2VisitorInterface::HEADER_HTTP_MESSAGING; + return; + } else if (validation_result != HeaderValidator::HEADER_OK) { + QUICHE_VLOG(2) << "RST_STREAM: invalid header found"; + // TODO(birenroy): consider updating this to return HEADER_HTTP_MESSAGING. + result_ = Http2VisitorInterface::HEADER_RST_STREAM; + return; + } + result_ = visitor_.OnHeaderForStream(stream_id_, key, value); } void OgHttp2Session::PassthroughHeadersHandler::OnHeaderBlockEnd( size_t /* uncompressed_header_bytes */, size_t /* compressed_header_bytes */) { - visitor_.OnEndHeadersForStream(stream_id_); + if (result_ == Http2VisitorInterface::HEADER_OK) { + if (!validator_.FinishHeaderBlock(type_)) { + result_ = Http2VisitorInterface::HEADER_RST_STREAM; + } + } + if (frame_contains_fin_ && IsResponse(type_) && + StatusIs1xx(status_header())) { + // Unexpected end of stream without final headers. + result_ = Http2VisitorInterface::HEADER_HTTP_MESSAGING; + } + if (result_ == Http2VisitorInterface::HEADER_OK) { + const bool result = visitor_.OnEndHeadersForStream(stream_id_); + if (!result) { + session_.decoder_.StopProcessing(); + } + } else { + session_.OnHeaderStatus(stream_id_, result_); + } + frame_contains_fin_ = false; } OgHttp2Session::OgHttp2Session(Http2VisitorInterface& visitor, Options options) - : visitor_(visitor), headers_handler_(visitor), options_(options) { - decoder_.set_visitor(this); + : visitor_(visitor), + event_forwarder_([this]() { return !latched_error_; }, *this), + receive_logger_( + &event_forwarder_, TracePerspectiveAsString(options.perspective), + []() { return kTraceLoggingEnabled; }, this), + send_logger_( + TracePerspectiveAsString(options.perspective), + []() { return kTraceLoggingEnabled; }, this), + headers_handler_(*this, visitor), + noop_headers_handler_(/*listener=*/nullptr), + connection_window_manager_(kInitialFlowControlWindowSize, + [this](size_t window_update_delta) { + SendWindowUpdate(kConnectionStreamId, + window_update_delta); + }), + options_(options) { + decoder_.set_visitor(&receive_logger_); + decoder_.set_extension_visitor(this); if (options_.perspective == Perspective::kServer) { remaining_preface_ = {spdy::kHttp2ConnectionHeaderPrefix, spdy::kHttp2ConnectionHeaderPrefixSize}; @@ -30,8 +287,93 @@ OgHttp2Session::OgHttp2Session(Http2VisitorInterface& visitor, Options options) OgHttp2Session::~OgHttp2Session() {} -ssize_t OgHttp2Session::ProcessBytes(absl::string_view bytes) { - ssize_t preface_consumed = 0; +void OgHttp2Session::SetStreamUserData(Http2StreamId stream_id, + void* user_data) { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + it->second.user_data = user_data; + } +} + +void* OgHttp2Session::GetStreamUserData(Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.user_data; + } + return nullptr; +} + +bool OgHttp2Session::ResumeStream(Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end() || it->second.outbound_body == nullptr || + !write_scheduler_.StreamRegistered(stream_id)) { + return false; + } + it->second.data_deferred = false; + write_scheduler_.MarkStreamReady(stream_id, /*add_to_front=*/false); + return true; +} + +int OgHttp2Session::GetStreamSendWindowSize(Http2StreamId stream_id) const { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.send_window; + } + return -1; +} + +int OgHttp2Session::GetStreamReceiveWindowLimit(Http2StreamId stream_id) const { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.window_manager.WindowSizeLimit(); + } + return -1; +} + +int OgHttp2Session::GetStreamReceiveWindowSize(Http2StreamId stream_id) const { + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + return it->second.window_manager.CurrentWindowSize(); + } + return -1; +} + +int OgHttp2Session::GetReceiveWindowSize() const { + return connection_window_manager_.CurrentWindowSize(); +} + +int OgHttp2Session::GetHpackEncoderDynamicTableSize() const { + const spdy::HpackEncoder* encoder = framer_.GetHpackEncoder(); + return encoder == nullptr ? 0 : encoder->GetDynamicTableSize(); +} + +int OgHttp2Session::GetHpackEncoderDynamicTableCapacity() const { + const spdy::HpackEncoder* encoder = framer_.GetHpackEncoder(); + return encoder == nullptr ? kDefaultHpackTableCapacity + : encoder->CurrentHeaderTableSizeSetting(); +} + +int OgHttp2Session::GetHpackDecoderDynamicTableSize() const { + const spdy::HpackDecoderAdapter* decoder = decoder_.GetHpackDecoder(); + return decoder == nullptr ? 0 : decoder->GetDynamicTableSize(); +} + +int OgHttp2Session::GetHpackDecoderSizeLimit() const { + const spdy::HpackDecoderAdapter* decoder = decoder_.GetHpackDecoder(); + return decoder == nullptr ? 0 : decoder->GetCurrentHeaderTableSizeSetting(); +} + +int64_t OgHttp2Session::ProcessBytes(absl::string_view bytes) { + QUICHE_VLOG(2) << TracePerspectiveAsString(options_.perspective) + << " processing [" << absl::CEscape(bytes) << "]"; + if (processing_bytes_) { + QUICHE_VLOG(1) << "Returning early; already processing bytes."; + return 0; + } + processing_bytes_ = true; + RunOnExit r{[this]() { processing_bytes_ = false; }}; + + int64_t preface_consumed = 0; if (!remaining_preface_.empty()) { QUICHE_VLOG(2) << "Preface bytes remaining: " << remaining_preface_.size(); // decoder_ does not understand the client connection preface. @@ -41,7 +383,8 @@ ssize_t OgHttp2Session::ProcessBytes(absl::string_view bytes) { QUICHE_DLOG(INFO) << "Preface doesn't match! Expected: [" << absl::CEscape(remaining_preface_) << "], actual: [" << absl::CEscape(bytes) << "]"; - visitor_.OnConnectionError(); + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidConnectionPreface); return -1; } remaining_preface_.remove_prefix(min_size); @@ -53,8 +396,14 @@ ssize_t OgHttp2Session::ProcessBytes(absl::string_view bytes) { } preface_consumed = min_size; } - ssize_t result = decoder_.ProcessInput(bytes.data(), bytes.size()); - return result < 0 ? result : result + preface_consumed; + int64_t result = decoder_.ProcessInput(bytes.data(), bytes.size()); + if (latched_error_) { + QUICHE_VLOG(2) << "ProcessBytes encountered an error."; + return -1; + } + const int64_t ret = result < 0 ? result : result + preface_consumed; + QUICHE_VLOG(2) << "ProcessBytes returning: " << ret; + return ret; } int OgHttp2Session::Consume(Http2StreamId stream_id, size_t num_bytes) { @@ -66,86 +415,557 @@ int OgHttp2Session::Consume(Http2StreamId stream_id, size_t num_bytes) { } else { it->second.window_manager.MarkDataFlushed(num_bytes); } + connection_window_manager_.MarkDataFlushed(num_bytes); return 0; // Remove? } +void OgHttp2Session::StartGracefulShutdown() { + if (options_.perspective == Perspective::kServer) { + if (!queued_goaway_) { + EnqueueFrame(absl::make_unique( + std::numeric_limits::max(), spdy::ERROR_CODE_NO_ERROR, + "graceful_shutdown")); + } + } else { + QUICHE_LOG(ERROR) << "Graceful shutdown not needed for clients."; + } +} + void OgHttp2Session::EnqueueFrame(std::unique_ptr frame) { + RunOnExit r; + if (frame->frame_type() == spdy::SpdyFrameType::GOAWAY) { + queued_goaway_ = true; + } else if (frame->fin() || + frame->frame_type() == spdy::SpdyFrameType::RST_STREAM) { + auto iter = stream_map_.find(frame->stream_id()); + if (iter != stream_map_.end()) { + iter->second.half_closed_local = true; + } + if (frame->frame_type() == spdy::SpdyFrameType::RST_STREAM) { + streams_reset_.insert(frame->stream_id()); + } else if (iter != stream_map_.end()) { + // Enqueue RST_STREAM NO_ERROR if appropriate. + r.emplace([this, iter]() { MaybeFinWithRstStream(iter); }); + } + } + if (frame->stream_id() != 0) { + auto result = queued_frames_.insert({frame->stream_id(), 1}); + if (!result.second) { + ++(result.first->second); + } + } frames_.push_back(std::move(frame)); } -std::string OgHttp2Session::GetBytesToWrite(absl::optional max_bytes) { - const size_t serialized_max = - max_bytes ? max_bytes.value() : std::numeric_limits::max(); - std::string serialized = std::move(serialized_prefix_); - while (serialized.size() < serialized_max && !frames_.empty()) { - spdy::SpdySerializedFrame frame = framer_.SerializeFrame(*frames_.front()); - absl::StrAppend(&serialized, absl::string_view(frame)); - frames_.pop_front(); +int OgHttp2Session::Send() { + if (sending_) { + QUICHE_VLOG(1) << TracePerspectiveAsString(options_.perspective) + << " returning early; already sending."; + return 0; + } + sending_ = true; + RunOnExit r{[this]() { sending_ = false; }}; + + MaybeSetupPreface(); + + SendResult continue_writing = SendQueuedFrames(); + while (continue_writing == SendResult::SEND_OK && + !connection_metadata_.empty()) { + continue_writing = SendMetadata(0, connection_metadata_); } - if (serialized.size() > serialized_max) { - serialized_prefix_ = serialized.substr(serialized_max); - serialized.resize(serialized_max); + // Wake streams for writes. + while (continue_writing == SendResult::SEND_OK && + write_scheduler_.HasReadyStreams() && connection_send_window_ > 0) { + const Http2StreamId stream_id = write_scheduler_.PopNextReadyStream(); + // TODO(birenroy): Add a return value to indicate write blockage, so streams + // aren't woken unnecessarily. + QUICHE_VLOG(1) << "Waking stream " << stream_id << " for writes."; + continue_writing = WriteForStream(stream_id); } - return serialized; + if (continue_writing == SendResult::SEND_OK) { + continue_writing = SendQueuedFrames(); + } + return continue_writing == SendResult::SEND_ERROR ? -1 : 0; } -void OgHttp2Session::OnError(http2::Http2DecoderAdapter::SpdyFramerError error, +OgHttp2Session::SendResult OgHttp2Session::MaybeSendBufferedData() { + int64_t result = std::numeric_limits::max(); + while (result > 0 && !buffered_data_.empty()) { + result = visitor_.OnReadyToSend(buffered_data_); + if (result > 0) { + buffered_data_.erase(0, result); + } + } + if (result < 0) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kSendError); + return SendResult::SEND_ERROR; + } + return buffered_data_.empty() ? SendResult::SEND_OK + : SendResult::SEND_BLOCKED; +} + +OgHttp2Session::SendResult OgHttp2Session::SendQueuedFrames() { + // Flush any serialized prefix. + const SendResult result = MaybeSendBufferedData(); + if (result != SendResult::SEND_OK) { + return result; + } + // Serialize and send frames in the queue. + while (!frames_.empty()) { + const auto& frame_ptr = frames_.front(); + FrameAttributeCollector c; + frame_ptr->Visit(&c); + // Frames can't accurately report their own length; the actual serialized + // length must be used instead. + spdy::SpdySerializedFrame frame = framer_.SerializeFrame(*frame_ptr); + const size_t frame_payload_length = frame.size() - spdy::kFrameHeaderSize; + frame_ptr->Visit(&send_logger_); + visitor_.OnBeforeFrameSent(c.frame_type(), c.stream_id(), + frame_payload_length, c.flags()); + const int64_t result = visitor_.OnReadyToSend(absl::string_view(frame)); + if (result < 0) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kSendError); + return SendResult::SEND_ERROR; + } else if (result == 0) { + // Write blocked. + return SendResult::SEND_BLOCKED; + } else { + AfterFrameSent(c.frame_type(), c.stream_id(), frame_payload_length, + c.flags(), c.error_code()); + + frames_.pop_front(); + if (static_cast(result) < frame.size()) { + // The frame was partially written, so the rest must be buffered. + buffered_data_.append(frame.data() + result, frame.size() - result); + return SendResult::SEND_BLOCKED; + } + } + } + return SendResult::SEND_OK; +} + +void OgHttp2Session::AfterFrameSent(uint8_t frame_type, uint32_t stream_id, + size_t payload_length, uint8_t flags, + uint32_t error_code) { + visitor_.OnFrameSent(frame_type, stream_id, payload_length, flags, + error_code); + if (stream_id == 0) { + const bool is_settings_ack = + static_cast(frame_type) == FrameType::SETTINGS && + (flags & 0x01); + if (is_settings_ack && encoder_header_table_capacity_when_acking_) { + framer_.UpdateHeaderEncoderTableSize( + encoder_header_table_capacity_when_acking_.value()); + encoder_header_table_capacity_when_acking_ = absl::nullopt; + } + return; + } + auto iter = queued_frames_.find(stream_id); + if (frame_type != 0) { + --iter->second; + } + if (iter->second == 0) { + // TODO(birenroy): Consider passing through `error_code` here. + CloseStreamIfReady(frame_type, stream_id); + } +} + +OgHttp2Session::SendResult OgHttp2Session::WriteForStream( + Http2StreamId stream_id) { + auto it = stream_map_.find(stream_id); + if (it == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Can't find stream " << stream_id + << " which is ready to write!"; + return SendResult::SEND_OK; + } + StreamState& state = it->second; + SendResult connection_can_write = SendResult::SEND_OK; + if (!state.outbound_metadata.empty()) { + connection_can_write = SendMetadata(stream_id, state.outbound_metadata); + } + + if (state.outbound_body == nullptr) { + // No data to send, but there might be trailers. + if (state.trailers != nullptr) { + auto block_ptr = std::move(state.trailers); + if (state.half_closed_local) { + QUICHE_LOG(ERROR) << "Sent fin; can't send trailers."; + } else { + SendTrailers(stream_id, std::move(*block_ptr)); + } + } + return SendResult::SEND_OK; + } + int32_t available_window = + std::min({connection_send_window_, state.send_window, + static_cast(max_frame_payload_)}); + while (connection_can_write == SendResult::SEND_OK && available_window > 0 && + state.outbound_body != nullptr && !state.data_deferred) { + int64_t length; + bool end_data; + std::tie(length, end_data) = + state.outbound_body->SelectPayloadLength(available_window); + QUICHE_VLOG(2) << "WriteForStream | length: " << length + << " end_data: " << end_data + << " trailers: " << state.trailers.get(); + if (length == 0 && !end_data && + (options_.trailers_require_end_data || state.trailers == nullptr)) { + // An unproductive call to SelectPayloadLength() results in this stream + // entering the "deferred" state only if either no trailers are available + // to send, or trailers require an explicit end_data before being sent. + state.data_deferred = true; + break; + } else if (length == DataFrameSource::kError) { + // TODO(birenroy): Consider queuing a RST_STREAM INTERNAL_ERROR instead. + CloseStream(stream_id, Http2ErrorCode::INTERNAL_ERROR); + // No more work on the stream; it has been closed. + break; + } + const bool fin = end_data ? state.outbound_body->send_fin() : false; + if (length > 0 || fin) { + spdy::SpdyDataIR data(stream_id); + data.set_fin(fin); + data.SetDataShallow(length); + spdy::SpdySerializedFrame header = + spdy::SpdyFramer::SerializeDataFrameHeaderWithPaddingLengthField( + data); + QUICHE_DCHECK(buffered_data_.empty() && frames_.empty()); + const bool success = + state.outbound_body->Send(absl::string_view(header), length); + if (!success) { + connection_can_write = SendResult::SEND_BLOCKED; + break; + } + connection_send_window_ -= length; + state.send_window -= length; + available_window = std::min({connection_send_window_, state.send_window, + static_cast(max_frame_payload_)}); + if (fin) { + state.half_closed_local = true; + MaybeFinWithRstStream(it); + } + AfterFrameSent(/* DATA */ 0, stream_id, length, fin ? 0x1 : 0x0, 0); + if (!stream_map_.contains(stream_id)) { + // Note: the stream may have been closed if `fin` is true. + break; + } + } + if (end_data || (length == 0 && state.trailers != nullptr && + !options_.trailers_require_end_data)) { + // If SelectPayloadLength() returned {0, false}, and there are trailers to + // send, and the safety feature is disabled, it's okay to send the + // trailers. + if (state.trailers != nullptr) { + auto block_ptr = std::move(state.trailers); + if (fin) { + QUICHE_LOG(ERROR) << "Sent fin; can't send trailers."; + } else { + SendTrailers(stream_id, std::move(*block_ptr)); + } + } + state.outbound_body = nullptr; + } + } + // If the stream still exists and has data to send, it should be marked as + // ready in the write scheduler. + if (stream_map_.contains(stream_id) && !state.data_deferred && + state.send_window > 0 && state.outbound_body != nullptr) { + write_scheduler_.MarkStreamReady(stream_id, false); + } + // Streams can continue writing as long as the connection is not write-blocked + // and there is additional flow control quota available. + if (connection_can_write != SendResult::SEND_OK) { + return connection_can_write; + } + return available_window <= 0 ? SendResult::SEND_BLOCKED : SendResult::SEND_OK; +} + +OgHttp2Session::SendResult OgHttp2Session::SendMetadata( + Http2StreamId stream_id, OgHttp2Session::MetadataSequence& sequence) { + const uint32_t max_payload_size = + std::min(kMaxAllowedMetadataFrameSize, max_frame_payload_); + auto payload_buffer = absl::make_unique(max_payload_size); + while (!sequence.empty()) { + MetadataSource& source = *sequence.front(); + + int64_t written; + bool end_metadata; + std::tie(written, end_metadata) = + source.Pack(payload_buffer.get(), max_payload_size); + if (written < 0) { + // Did not touch the connection, so perhaps writes are still possible. + return SendResult::SEND_OK; + } + QUICHE_DCHECK_LE(static_cast(written), max_payload_size); + auto payload = absl::string_view( + reinterpret_cast(payload_buffer.get()), written); + EnqueueFrame(absl::make_unique( + stream_id, kMetadataFrameType, end_metadata ? kMetadataEndFlag : 0u, + std::string(payload))); + if (end_metadata) { + sequence.erase(sequence.begin()); + } + } + return SendQueuedFrames(); +} + +int32_t OgHttp2Session::SubmitRequest( + absl::Span headers, + std::unique_ptr data_source, void* user_data) { + // TODO(birenroy): return an error for the incorrect perspective + const Http2StreamId stream_id = next_stream_id_; + next_stream_id_ += 2; + if (CanCreateStream()) { + StartRequest(stream_id, ToHeaderBlock(headers), std::move(data_source), + user_data); + } else { + // TODO(diannahu): There should probably be a limit to the number of allowed + // pending streams. + pending_streams_.push_back( + {stream_id, ToHeaderBlock(headers), std::move(data_source), user_data}); + } + return stream_id; +} + +int OgHttp2Session::SubmitResponse( + Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source) { + // TODO(birenroy): return an error for the incorrect perspective + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Unable to find stream " << stream_id; + return -501; // NGHTTP2_ERR_INVALID_ARGUMENT + } + const bool end_stream = data_source == nullptr; + if (!end_stream) { + // Add data source to stream state + iter->second.outbound_body = std::move(data_source); + write_scheduler_.MarkStreamReady(stream_id, false); + } + SendHeaders(stream_id, ToHeaderBlock(headers), end_stream); + return 0; +} + +int OgHttp2Session::SubmitTrailer(Http2StreamId stream_id, + absl::Span trailers) { + // TODO(birenroy): Reject trailers when acting as a client? + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end()) { + QUICHE_LOG(ERROR) << "Unable to find stream " << stream_id; + return -501; // NGHTTP2_ERR_INVALID_ARGUMENT + } + StreamState& state = iter->second; + if (state.half_closed_local) { + QUICHE_LOG(ERROR) << "Stream " << stream_id << " is half closed (local)"; + return -514; // NGHTTP2_ERR_INVALID_STREAM_STATE + } + if (state.trailers != nullptr) { + QUICHE_LOG(ERROR) << "Stream " << stream_id + << " already has trailers queued"; + return -514; // NGHTTP2_ERR_INVALID_STREAM_STATE + } + if (state.outbound_body == nullptr) { + // Enqueue trailers immediately. + SendTrailers(stream_id, ToHeaderBlock(trailers)); + } else { + QUICHE_LOG_IF(ERROR, state.outbound_body->send_fin()) + << "DataFrameSource will send fin, preventing trailers!"; + // Save trailers so they can be written once data is done. + state.trailers = + absl::make_unique(ToHeaderBlock(trailers)); + if (!options_.trailers_require_end_data) { + iter->second.data_deferred = false; + } + if (!iter->second.data_deferred) { + write_scheduler_.MarkStreamReady(stream_id, false); + } + } + return 0; +} + +void OgHttp2Session::SubmitMetadata(Http2StreamId stream_id, + std::unique_ptr source) { + if (stream_id == 0) { + connection_metadata_.push_back(std::move(source)); + } else { + auto iter = CreateStream(stream_id); + iter->second.outbound_metadata.push_back(std::move(source)); + write_scheduler_.MarkStreamReady(stream_id, false); + } +} + +void OgHttp2Session::SubmitSettings(absl::Span settings) { + EnqueueFrame(PrepareSettingsFrame(settings)); +} + +void OgHttp2Session::OnError(SpdyFramerError error, std::string detailed_error) { QUICHE_VLOG(1) << "Error: " << http2::Http2DecoderAdapter::SpdyFramerErrorToString(error) << " details: " << detailed_error; - visitor_.OnConnectionError(); + // TODO(diannahu): Consider propagating `detailed_error`. + LatchErrorAndNotify(GetHttp2ErrorCode(error), ConnectionError::kParseError); } -void OgHttp2Session::OnCommonHeader(spdy::SpdyStreamId /*stream_id*/, - size_t /*length*/, - uint8_t /*type*/, - uint8_t /*flags*/) {} +void OgHttp2Session::OnCommonHeader(spdy::SpdyStreamId stream_id, + size_t length, + uint8_t type, + uint8_t flags) { + highest_received_stream_id_ = std::max(static_cast(stream_id), + highest_received_stream_id_); + const bool result = visitor_.OnFrameHeader(stream_id, length, type, flags); + if (!result) { + decoder_.StopProcessing(); + } +} void OgHttp2Session::OnDataFrameHeader(spdy::SpdyStreamId stream_id, - size_t length, - bool fin) { - visitor_.OnBeginDataForStream(stream_id, length); + size_t length, bool /*fin*/) { + if (!stream_map_.contains(stream_id)) { + // The stream does not exist; it could be an error or a benign close, e.g., + // getting data for a stream this connection recently closed. + if (static_cast(stream_id) > highest_processed_stream_id_) { + // Receiving DATA before HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kWrongFrameSequence); + } + return; + } + + const bool result = visitor_.OnBeginDataForStream(stream_id, length); + if (!result) { + decoder_.StopProcessing(); + } } void OgHttp2Session::OnStreamFrameData(spdy::SpdyStreamId stream_id, const char* data, size_t len) { - visitor_.OnDataForStream(stream_id, absl::string_view(data, len)); + // Count the data against flow control, even if the stream is unknown. + MarkDataBuffered(stream_id, len); + + if (!stream_map_.contains(stream_id)) { + // If the stream was unknown due to a protocol error, the visitor was + // informed in OnDataFrameHeader(). + return; + } + + const bool result = + visitor_.OnDataForStream(stream_id, absl::string_view(data, len)); + if (!result) { + decoder_.StopProcessing(); + } } void OgHttp2Session::OnStreamEnd(spdy::SpdyStreamId stream_id) { - visitor_.OnEndStream(stream_id); + auto iter = stream_map_.find(stream_id); + if (iter != stream_map_.end()) { + iter->second.half_closed_remote = true; + visitor_.OnEndStream(stream_id); + } + auto queued_frames_iter = queued_frames_.find(stream_id); + const bool no_queued_frames = queued_frames_iter == queued_frames_.end() || + queued_frames_iter->second == 0; + if (iter != stream_map_.end() && iter->second.half_closed_local && + options_.perspective == Perspective::kClient && no_queued_frames) { + // From the client's perspective, the stream can be closed if it's already + // half_closed_local. + CloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR); + } } -void OgHttp2Session::OnStreamPadLength(spdy::SpdyStreamId /*stream_id*/, - size_t /*value*/) {} +void OgHttp2Session::OnStreamPadLength(spdy::SpdyStreamId stream_id, + size_t value) { + MarkDataBuffered(stream_id, 1 + value); + // TODO(181586191): Pass padding to the visitor? +} -void OgHttp2Session::OnStreamPadding(spdy::SpdyStreamId stream_id, size_t len) { +void OgHttp2Session::OnStreamPadding(spdy::SpdyStreamId /*stream_id*/, size_t + /*len*/) { + // Flow control was accounted for in OnStreamPadLength(). + // TODO(181586191): Pass padding to the visitor? } spdy::SpdyHeadersHandlerInterface* OgHttp2Session::OnHeaderFrameStart( spdy::SpdyStreamId stream_id) { - headers_handler_.set_stream_id(stream_id); - return &headers_handler_; + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + headers_handler_.set_stream_id(stream_id); + headers_handler_.set_header_type( + NextHeaderType(it->second.received_header_type)); + return &headers_handler_; + } else { + return &noop_headers_handler_; + } } void OgHttp2Session::OnHeaderFrameEnd(spdy::SpdyStreamId stream_id) { - headers_handler_.set_stream_id(0); + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + if (headers_handler_.header_type() == HeaderType::RESPONSE && + !headers_handler_.status_header().empty() && + headers_handler_.status_header()[0] == '1') { + // If response headers carried a 1xx response code, final response headers + // should still be forthcoming. + it->second.received_header_type = HeaderType::RESPONSE_100; + } else { + it->second.received_header_type = headers_handler_.header_type(); + } + headers_handler_.set_stream_id(0); + } } void OgHttp2Session::OnRstStream(spdy::SpdyStreamId stream_id, spdy::SpdyErrorCode error_code) { + auto iter = stream_map_.find(stream_id); + if (iter != stream_map_.end()) { + iter->second.half_closed_remote = true; + iter->second.outbound_body = nullptr; + } else if (static_cast(stream_id) > + highest_processed_stream_id_) { + // Receiving RST_STREAM before HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kWrongFrameSequence); + return; + } visitor_.OnRstStream(stream_id, TranslateErrorCode(error_code)); - visitor_.OnAbortStream(stream_id, TranslateErrorCode(error_code)); + // TODO(birenroy): Consider whether there are outbound frames queued for the + // stream. + CloseStream(stream_id, TranslateErrorCode(error_code)); } void OgHttp2Session::OnSettings() { visitor_.OnSettingsStart(); + auto settings = absl::make_unique(); + settings->set_is_ack(true); + EnqueueFrame(std::move(settings)); } void OgHttp2Session::OnSetting(spdy::SpdySettingsId id, uint32_t value) { visitor_.OnSetting({id, value}); + if (id == kMetadataExtensionId) { + peer_supports_metadata_ = (value != 0); + } else if (id == MAX_FRAME_SIZE) { + max_frame_payload_ = value; + } else if (id == MAX_CONCURRENT_STREAMS) { + max_outbound_concurrent_streams_ = value; + } else if (id == HEADER_TABLE_SIZE) { + value = std::min(value, HpackCapacityBound(options_)); + if (value < framer_.GetHpackEncoder()->CurrentHeaderTableSizeSetting()) { + // Safe to apply a smaller table capacity immediately. + QUICHE_VLOG(2) << TracePerspectiveAsString(options_.perspective) + << " applying encoder table capacity " << value; + framer_.GetHpackEncoder()->ApplyHeaderTableSizeSetting(value); + } else { + QUICHE_VLOG(2) + << TracePerspectiveAsString(options_.perspective) + << " NOT applying encoder table capacity until writing ack: " + << value; + encoder_header_table_capacity_when_acking_ = value; + } + } } void OgHttp2Session::OnSettingsEnd() { @@ -153,53 +973,126 @@ void OgHttp2Session::OnSettingsEnd() { } void OgHttp2Session::OnSettingsAck() { + if (!settings_ack_callbacks_.empty()) { + SettingsAckCallback callback = std::move(settings_ack_callbacks_.front()); + settings_ack_callbacks_.pop_front(); + callback(); + } + visitor_.OnSettingsAck(); } void OgHttp2Session::OnPing(spdy::SpdyPingId unique_id, bool is_ack) { visitor_.OnPing(unique_id, is_ack); + if (options_.auto_ping_ack && !is_ack) { + auto ping = absl::make_unique(unique_id); + ping->set_is_ack(true); + EnqueueFrame(std::move(ping)); + } } void OgHttp2Session::OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, spdy::SpdyErrorCode error_code) { received_goaway_ = true; - visitor_.OnGoAway(last_accepted_stream_id, TranslateErrorCode(error_code), - ""); + const bool result = visitor_.OnGoAway(last_accepted_stream_id, + TranslateErrorCode(error_code), ""); + if (!result) { + decoder_.StopProcessing(); + } } -bool OgHttp2Session::OnGoAwayFrameData(const char* goaway_data, size_t len) { +bool OgHttp2Session::OnGoAwayFrameData(const char* /*goaway_data*/, size_t + /*len*/) { // Opaque data is currently ignored. return true; } void OgHttp2Session::OnHeaders(spdy::SpdyStreamId stream_id, - bool has_priority, - int weight, - spdy::SpdyStreamId parent_stream_id, - bool exclusive, - bool fin, - bool end) {} + bool /*has_priority*/, int /*weight*/, + spdy::SpdyStreamId /*parent_stream_id*/, + bool /*exclusive*/, bool fin, bool /*end*/) { + if (stream_id % 2 == 0) { + // Server push is disabled; receiving push HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidNewStreamId); + return; + } + if (fin) { + headers_handler_.set_frame_contains_fin(); + } + if (options_.perspective == Perspective::kServer) { + const auto new_stream_id = static_cast(stream_id); + if (new_stream_id <= highest_processed_stream_id_) { + // A new stream ID lower than the watermark is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidNewStreamId); + return; + } + + if (stream_map_.size() >= max_inbound_concurrent_streams_) { + // The new stream would exceed our advertised and acknowledged + // MAX_CONCURRENT_STREAMS. For parity with nghttp2, treat this error as a + // connection-level PROTOCOL_ERROR. + visitor_.OnInvalidFrame( + stream_id, Http2VisitorInterface::InvalidFrameError::kProtocol); + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kExceededMaxConcurrentStreams); + return; + } + if (stream_map_.size() >= pending_max_inbound_concurrent_streams_) { + // The new stream would exceed our advertised but unacked + // MAX_CONCURRENT_STREAMS. Refuse the stream for parity with nghttp2. + EnqueueFrame(absl::make_unique( + stream_id, spdy::ERROR_CODE_REFUSED_STREAM)); + const bool ok = visitor_.OnInvalidFrame( + stream_id, Http2VisitorInterface::InvalidFrameError::kRefusedStream); + if (!ok) { + LatchErrorAndNotify(Http2ErrorCode::REFUSED_STREAM, + ConnectionError::kExceededMaxConcurrentStreams); + } + return; + } + + CreateStream(stream_id); + } +} void OgHttp2Session::OnWindowUpdate(spdy::SpdyStreamId stream_id, int delta_window_size) { if (stream_id == 0) { - peer_window_ += delta_window_size; + connection_send_window_ += delta_window_size; } else { auto it = stream_map_.find(stream_id); if (it == stream_map_.end()) { QUICHE_VLOG(1) << "Stream " << stream_id << " not found!"; + if (static_cast(stream_id) > + highest_processed_stream_id_) { + // Receiving WINDOW_UPDATE before HEADERS is a connection error. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kWrongFrameSequence); + return; + } } else { + if (it->second.send_window == 0) { + // The stream was blocked on flow control. + write_scheduler_.MarkStreamReady(stream_id, false); + } it->second.send_window += delta_window_size; } } visitor_.OnWindowUpdate(stream_id, delta_window_size); } -void OgHttp2Session::OnPushPromise(spdy::SpdyStreamId stream_id, - spdy::SpdyStreamId promised_stream_id, - bool end) {} +void OgHttp2Session::OnPushPromise(spdy::SpdyStreamId /*stream_id*/, + spdy::SpdyStreamId /*promised_stream_id*/, + bool /*end*/) { + // Server push is disabled; PUSH_PROMISE is an invalid frame. + LatchErrorAndNotify(Http2ErrorCode::PROTOCOL_ERROR, + ConnectionError::kInvalidPushPromise); +} -void OgHttp2Session::OnContinuation(spdy::SpdyStreamId stream_id, bool end) {} +void OgHttp2Session::OnContinuation(spdy::SpdyStreamId /*stream_id*/, bool + /*end*/) {} void OgHttp2Session::OnAltSvc(spdy::SpdyStreamId /*stream_id*/, absl::string_view /*origin*/, @@ -207,18 +1100,285 @@ void OgHttp2Session::OnAltSvc(spdy::SpdyStreamId /*stream_id*/, AlternativeServiceVector& /*altsvc_vector*/) { } -void OgHttp2Session::OnPriority(spdy::SpdyStreamId stream_id, - spdy::SpdyStreamId parent_stream_id, - int weight, - bool exclusive) {} +void OgHttp2Session::OnPriority(spdy::SpdyStreamId /*stream_id*/, + spdy::SpdyStreamId /*parent_stream_id*/, + int /*weight*/, bool /*exclusive*/) {} -void OgHttp2Session::OnPriorityUpdate(spdy::SpdyStreamId prioritized_stream_id, - absl::string_view priority_field_value) {} +void OgHttp2Session::OnPriorityUpdate( + spdy::SpdyStreamId /*prioritized_stream_id*/, + absl::string_view /*priority_field_value*/) {} -bool OgHttp2Session::OnUnknownFrame(spdy::SpdyStreamId stream_id, - uint8_t frame_type) { +bool OgHttp2Session::OnUnknownFrame(spdy::SpdyStreamId /*stream_id*/, + uint8_t /*frame_type*/) { return true; } +void OgHttp2Session::OnHeaderStatus( + Http2StreamId stream_id, Http2VisitorInterface::OnHeaderResult result) { + QUICHE_DCHECK_NE(result, Http2VisitorInterface::HEADER_OK); + const bool should_reset_stream = + result == Http2VisitorInterface::HEADER_RST_STREAM || + result == Http2VisitorInterface::HEADER_HTTP_MESSAGING; + if (should_reset_stream) { + const Http2ErrorCode error_code = + (result == Http2VisitorInterface::HEADER_RST_STREAM) + ? Http2ErrorCode::INTERNAL_ERROR + : Http2ErrorCode::PROTOCOL_ERROR; + const spdy::SpdyErrorCode spdy_error_code = TranslateErrorCode(error_code); + const Http2VisitorInterface::InvalidFrameError frame_error = + (result == Http2VisitorInterface::HEADER_RST_STREAM) + ? Http2VisitorInterface::InvalidFrameError::kHttpHeader + : Http2VisitorInterface::InvalidFrameError::kHttpMessaging; + auto it = streams_reset_.find(stream_id); + if (it == streams_reset_.end()) { + EnqueueFrame( + absl::make_unique(stream_id, spdy_error_code)); + + const bool ok = visitor_.OnInvalidFrame(stream_id, frame_error); + if (!ok) { + LatchErrorAndNotify(error_code, ConnectionError::kHeaderError); + } + } + } else if (result == Http2VisitorInterface::HEADER_CONNECTION_ERROR) { + LatchErrorAndNotify(Http2ErrorCode::INTERNAL_ERROR, + ConnectionError::kHeaderError); + } +} + +bool OgHttp2Session::OnFrameHeader(spdy::SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + if (type == kMetadataFrameType) { + QUICHE_DCHECK_EQ(metadata_length_, 0u); + visitor_.OnBeginMetadataForStream(stream_id, length); + metadata_stream_id_ = stream_id; + metadata_length_ = length; + end_metadata_ = flags & kMetadataEndFlag; + return true; + } else { + QUICHE_DLOG(INFO) << "Unexpected frame type " << static_cast(type) + << " received by the extension visitor."; + return false; + } +} + +void OgHttp2Session::OnFramePayload(const char* data, size_t len) { + if (metadata_length_ > 0) { + QUICHE_DCHECK_LE(len, metadata_length_); + const bool success = visitor_.OnMetadataForStream( + metadata_stream_id_, absl::string_view(data, len)); + if (success) { + metadata_length_ -= len; + if (metadata_length_ == 0 && end_metadata_) { + visitor_.OnMetadataEndForStream(metadata_stream_id_); + metadata_stream_id_ = 0; + end_metadata_ = false; + } + } else { + decoder_.StopProcessing(); + } + } else { + QUICHE_DLOG(INFO) << "Unexpected metadata payload for stream " + << metadata_stream_id_; + } +} + +void OgHttp2Session::MaybeSetupPreface() { + if (!queued_preface_) { + if (options_.perspective == Perspective::kClient) { + buffered_data_.assign(spdy::kHttp2ConnectionHeaderPrefix, + spdy::kHttp2ConnectionHeaderPrefixSize); + } + // First frame must be a non-ack SETTINGS. + if (frames_.empty() || + frames_.front()->frame_type() != spdy::SpdyFrameType::SETTINGS || + reinterpret_cast(*frames_.front()).is_ack()) { + frames_.push_front(PrepareSettingsFrame(GetInitialSettings())); + } + queued_preface_ = true; + } +} + +std::vector OgHttp2Session::GetInitialSettings() const { + std::vector settings; + if (!IsServerSession()) { + // Disable server push. Note that server push from clients is already + // disabled, so the server does not need to send this disabling setting. + // TODO(diannahu): Consider applying server push disabling on SETTINGS ack. + settings.push_back({Http2KnownSettingsId::ENABLE_PUSH, 0}); + } + return settings; +} + +std::unique_ptr OgHttp2Session::PrepareSettingsFrame( + absl::Span settings) { + auto settings_ir = absl::make_unique(); + for (const Http2Setting& setting : settings) { + settings_ir->AddSetting(setting.id, setting.value); + + if (setting.id == Http2KnownSettingsId::MAX_CONCURRENT_STREAMS) { + pending_max_inbound_concurrent_streams_ = setting.value; + } + } + + // Copy the (small) map of settings we are about to send so that we can set + // values in the SETTINGS ack callback. + settings_ack_callbacks_.push_back( + [this, settings_map = settings_ir->values()]() { + for (const auto id_and_value : settings_map) { + if (id_and_value.first == spdy::SETTINGS_MAX_CONCURRENT_STREAMS) { + max_inbound_concurrent_streams_ = id_and_value.second; + } else if (id_and_value.first == spdy::SETTINGS_HEADER_TABLE_SIZE) { + decoder_.GetHpackDecoder()->ApplyHeaderTableSizeSetting( + id_and_value.second); + } + } + }); + return settings_ir; +} + +void OgHttp2Session::SendWindowUpdate(Http2StreamId stream_id, + size_t update_delta) { + EnqueueFrame( + absl::make_unique(stream_id, update_delta)); +} + +void OgHttp2Session::SendHeaders(Http2StreamId stream_id, + spdy::SpdyHeaderBlock headers, + bool end_stream) { + auto frame = + absl::make_unique(stream_id, std::move(headers)); + frame->set_fin(end_stream); + EnqueueFrame(std::move(frame)); +} + +void OgHttp2Session::SendTrailers(Http2StreamId stream_id, + spdy::SpdyHeaderBlock trailers) { + auto frame = + absl::make_unique(stream_id, std::move(trailers)); + frame->set_fin(true); + EnqueueFrame(std::move(frame)); +} + +void OgHttp2Session::MaybeFinWithRstStream(StreamStateMap::iterator iter) { + QUICHE_DCHECK(iter != stream_map_.end() && iter->second.half_closed_local); + + if (options_.rst_stream_no_error_when_incomplete && + options_.perspective == Perspective::kServer && + !iter->second.half_closed_remote) { + // Since the peer has not yet ended the stream, this endpoint should + // send a RST_STREAM NO_ERROR. See RFC 7540 Section 8.1. + EnqueueFrame(absl::make_unique( + iter->first, spdy::SpdyErrorCode::ERROR_CODE_NO_ERROR)); + iter->second.half_closed_remote = true; + } +} + +void OgHttp2Session::MarkDataBuffered(Http2StreamId stream_id, size_t bytes) { + connection_window_manager_.MarkDataBuffered(bytes); + auto it = stream_map_.find(stream_id); + if (it != stream_map_.end()) { + it->second.window_manager.MarkDataBuffered(bytes); + } +} + +OgHttp2Session::StreamStateMap::iterator OgHttp2Session::CreateStream( + Http2StreamId stream_id) { + WindowManager::WindowUpdateListener listener = + [this, stream_id](size_t window_update_delta) { + SendWindowUpdate(stream_id, window_update_delta); + }; + absl::flat_hash_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = stream_map_.try_emplace( + stream_id, + StreamState(stream_receive_window_limit_, std::move(listener))); + if (inserted) { + // Add the stream to the write scheduler. + const WriteScheduler::StreamPrecedenceType precedence(3); + write_scheduler_.RegisterStream(stream_id, precedence); + + highest_processed_stream_id_ = + std::max(highest_processed_stream_id_, stream_id); + } + return iter; +} + +void OgHttp2Session::StartRequest(Http2StreamId stream_id, + spdy::SpdyHeaderBlock headers, + std::unique_ptr data_source, + void* user_data) { + auto iter = CreateStream(stream_id); + const bool end_stream = data_source == nullptr; + if (!end_stream) { + iter->second.outbound_body = std::move(data_source); + write_scheduler_.MarkStreamReady(stream_id, false); + } + iter->second.user_data = user_data; + SendHeaders(stream_id, std::move(headers), end_stream); +} + +void OgHttp2Session::CloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + visitor_.OnCloseStream(stream_id, error_code); + stream_map_.erase(stream_id); + if (write_scheduler_.StreamRegistered(stream_id)) { + write_scheduler_.UnregisterStream(stream_id); + } + + if (!pending_streams_.empty() && CanCreateStream()) { + PendingStreamState& pending_stream = pending_streams_.front(); + StartRequest(pending_stream.stream_id, std::move(pending_stream.headers), + std::move(pending_stream.data_source), + pending_stream.user_data); + pending_streams_.pop_front(); + } +} + +bool OgHttp2Session::CanCreateStream() const { + return stream_map_.size() < max_outbound_concurrent_streams_; +} + +HeaderType OgHttp2Session::NextHeaderType( + absl::optional current_type) { + if (IsServerSession()) { + return HeaderType::REQUEST; + } else if (!current_type || + current_type.value() == HeaderType::RESPONSE_100) { + return HeaderType::RESPONSE; + } else { + return HeaderType::RESPONSE_TRAILER; + } +} + +void OgHttp2Session::LatchErrorAndNotify(Http2ErrorCode error_code, + ConnectionError error) { + if (latched_error_) { + // Do not kick a connection when it is down. + return; + } + + latched_error_ = true; + visitor_.OnConnectionError(error); + decoder_.StopProcessing(); + if (IsServerSession()) { + EnqueueFrame(absl::make_unique( + highest_processed_stream_id_, TranslateErrorCode(error_code), + ConnectionErrorToString(error))); + } +} + +void OgHttp2Session::CloseStreamIfReady(uint8_t frame_type, + uint32_t stream_id) { + auto iter = stream_map_.find(stream_id); + if (iter == stream_map_.end()) { + return; + } + const StreamState& state = iter->second; + if (static_cast(frame_type) == FrameType::RST_STREAM || + (state.half_closed_local && state.half_closed_remote)) { + CloseStream(stream_id, Http2ErrorCode::HTTP2_NO_ERROR); + } +} + } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/oghttp2_session.h b/gquiche/http2/adapter/oghttp2_session.h index 440ccb9b..906b81ca 100644 --- a/gquiche/http2/adapter/oghttp2_session.h +++ b/gquiche/http2/adapter/oghttp2_session.h @@ -1,47 +1,131 @@ #ifndef QUICHE_HTTP2_ADAPTER_OGHTTP2_SESSION_H_ #define QUICHE_HTTP2_ADAPTER_OGHTTP2_SESSION_H_ +#include +#include #include +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/http2/adapter/data_source.h" +#include "gquiche/http2/adapter/event_forwarder.h" +#include "gquiche/http2/adapter/header_validator.h" +#include "gquiche/http2/adapter/http2_protocol.h" #include "gquiche/http2/adapter/http2_session.h" #include "gquiche/http2/adapter/http2_util.h" #include "gquiche/http2/adapter/http2_visitor_interface.h" #include "gquiche/http2/adapter/window_manager.h" +#include "gquiche/http2/core/http2_trace_logging.h" +#include "gquiche/http2/core/priority_write_scheduler.h" #include "gquiche/common/platform/api/quiche_bug_tracker.h" +#include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/spdy/core/http2_frame_decoder_adapter.h" +#include "gquiche/spdy/core/no_op_headers_handler.h" #include "gquiche/spdy/core/spdy_framer.h" +#include "gquiche/spdy/core/spdy_header_block.h" +#include "gquiche/spdy/core/spdy_protocol.h" namespace http2 { namespace adapter { // This class manages state associated with a single multiplexed HTTP/2 session. -class OgHttp2Session : public Http2Session, - public spdy::SpdyFramerVisitorInterface { +class QUICHE_EXPORT_PRIVATE OgHttp2Session + : public Http2Session, + public spdy::SpdyFramerVisitorInterface, + public spdy::ExtensionVisitorInterface { public: - struct Options { + struct QUICHE_EXPORT_PRIVATE Options { Perspective perspective = Perspective::kClient; + // The maximum HPACK table size to use. + absl::optional max_hpack_encoding_table_capacity = absl::nullopt; + // Whether to automatically send PING acks when receiving a PING. + bool auto_ping_ack = true; + // Whether (as server) to send a RST_STREAM NO_ERROR when sending a fin on + // an incomplete stream. + bool rst_stream_no_error_when_incomplete = false; + // Whether (as server) to queue trailers until after a stream's data source + // has indicated the end of data. If false, the server will assume that + // submitting trailers indicates the end of data. + bool trailers_require_end_data = false; }; - OgHttp2Session(Http2VisitorInterface& visitor, Options /*options*/); + OgHttp2Session(Http2VisitorInterface& visitor, Options options); ~OgHttp2Session() override; // Enqueues a frame for transmission to the peer. void EnqueueFrame(std::unique_ptr frame); - // If |want_write()| returns true, this method will return a non-empty string - // containing serialized HTTP/2 frames to write to the peer. - std::string GetBytesToWrite(absl::optional max_bytes); + // Starts a graceful shutdown sequence. No-op if a GOAWAY has already been + // sent. + void StartGracefulShutdown(); + + // Invokes the visitor's OnReadyToSend() method for serialized frames and + // DataFrameSource::Send() for data frames. + int Send(); + + int32_t SubmitRequest(absl::Span headers, + std::unique_ptr data_source, + void* user_data); + int SubmitResponse(Http2StreamId stream_id, absl::Span headers, + std::unique_ptr data_source); + int SubmitTrailer(Http2StreamId stream_id, absl::Span trailers); + void SubmitMetadata(Http2StreamId stream_id, + std::unique_ptr source); + void SubmitSettings(absl::Span settings); + + bool IsServerSession() const { + return options_.perspective == Perspective::kServer; + } + Http2StreamId GetHighestReceivedStreamId() const { + return highest_received_stream_id_; + } + void SetStreamUserData(Http2StreamId stream_id, void* user_data); + void* GetStreamUserData(Http2StreamId stream_id); + + // Resumes a stream that was previously blocked. Returns true on success. + bool ResumeStream(Http2StreamId stream_id); + + // Returns the peer's outstanding stream receive window for the given stream. + int GetStreamSendWindowSize(Http2StreamId stream_id) const; + + // Returns the current upper bound on the flow control receive window for this + // stream. + int GetStreamReceiveWindowLimit(Http2StreamId stream_id) const; + + // Returns the outstanding stream receive window, or -1 if the stream does not + // exist. + int GetStreamReceiveWindowSize(Http2StreamId stream_id) const; + + // Returns the outstanding connection receive window. + int GetReceiveWindowSize() const; + + // Returns the size of the HPACK encoder's dynamic table, including the + // per-entry overhead from the specification. + int GetHpackEncoderDynamicTableSize() const; + + // Returns the maximum capacity of the HPACK encoder's dynamic table. + int GetHpackEncoderDynamicTableCapacity() const; + + // Returns the size of the HPACK decoder's dynamic table, including the + // per-entry overhead from the specification. + int GetHpackDecoderDynamicTableSize() const; + + // Returns the size of the HPACK decoder's most recently applied size limit. + int GetHpackDecoderSizeLimit() const; // From Http2Session. - ssize_t ProcessBytes(absl::string_view bytes) override; + int64_t ProcessBytes(absl::string_view bytes) override; int Consume(Http2StreamId stream_id, size_t num_bytes) override; - bool want_read() const override { return !received_goaway_; } - bool want_write() const override { - return !frames_.empty() || !serialized_prefix_.empty(); + bool want_read() const override { + return !received_goaway_ && !decoder_.HasError(); } - int GetRemoteWindowSize() const override { - return peer_window_; + bool want_write() const override { + return !frames_.empty() || !buffered_data_.empty() || + write_scheduler_.HasReadyStreams() || !connection_metadata_.empty(); } + int GetRemoteWindowSize() const override { return connection_send_window_; } // From SpdyFramerVisitorInterface void OnError(http2::Http2DecoderAdapter::SpdyFramerError error, @@ -72,7 +156,7 @@ class OgHttp2Session : public Http2Session, void OnPing(spdy::SpdyPingId unique_id, bool is_ack) override; void OnGoAway(spdy::SpdyStreamId last_accepted_stream_id, spdy::SpdyErrorCode error_code) override; - bool OnGoAwayFrameData(const char* goaway_data, size_t len); + bool OnGoAwayFrameData(const char* goaway_data, size_t len) override; void OnHeaders(spdy::SpdyStreamId stream_id, bool has_priority, int weight, @@ -86,10 +170,9 @@ class OgHttp2Session : public Http2Session, spdy::SpdyStreamId promised_stream_id, bool end) override; void OnContinuation(spdy::SpdyStreamId stream_id, bool end) override; - void OnAltSvc(spdy::SpdyStreamId /*stream_id*/, - absl::string_view /*origin*/, + void OnAltSvc(spdy::SpdyStreamId /*stream_id*/, absl::string_view /*origin*/, const spdy::SpdyAltSvcWireFormat:: - AlternativeServiceVector& /*altsvc_vector*/); + AlternativeServiceVector& /*altsvc_vector*/) override; void OnPriority(spdy::SpdyStreamId stream_id, spdy::SpdyStreamId parent_stream_id, int weight, @@ -99,40 +182,260 @@ class OgHttp2Session : public Http2Session, bool OnUnknownFrame(spdy::SpdyStreamId stream_id, uint8_t frame_type) override; + // Invoked when header processing encounters an invalid or otherwise + // problematic header. + void OnHeaderStatus(Http2StreamId stream_id, + Http2VisitorInterface::OnHeaderResult result); + + // Returns true if a recognized extension frame is received. + bool OnFrameHeader(spdy::SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + + // Handles the payload for a recognized extension frame. + void OnFramePayload(const char* data, size_t len) override; + private: - struct StreamState { + using MetadataSequence = std::vector>; + + struct QUICHE_EXPORT_PRIVATE StreamState { + StreamState(int32_t stream_receive_window, + WindowManager::WindowUpdateListener listener) + : window_manager(stream_receive_window, std::move(listener)) {} + WindowManager window_manager; - int32_t send_window = 65535; + std::unique_ptr outbound_body; + MetadataSequence outbound_metadata; + std::unique_ptr trailers; + void* user_data = nullptr; + int32_t send_window = kInitialFlowControlWindowSize; + absl::optional received_header_type; bool half_closed_local = false; bool half_closed_remote = false; + // Indicates that `outbound_body` temporarily cannot produce data. + bool data_deferred = false; }; + using StreamStateMap = absl::flat_hash_map; - class PassthroughHeadersHandler : public spdy::SpdyHeadersHandlerInterface { + struct QUICHE_EXPORT_PRIVATE PendingStreamState { + Http2StreamId stream_id; + spdy::SpdyHeaderBlock headers; + std::unique_ptr data_source; + void* user_data = nullptr; + }; + + class QUICHE_EXPORT_PRIVATE PassthroughHeadersHandler + : public spdy::SpdyHeadersHandlerInterface { public: - explicit PassthroughHeadersHandler(Http2VisitorInterface& visitor) - : visitor_(visitor) {} - void set_stream_id(Http2StreamId stream_id) { stream_id_ = stream_id; } + explicit PassthroughHeadersHandler(OgHttp2Session& session, + Http2VisitorInterface& visitor) + : session_(session), visitor_(visitor) {} + + void set_stream_id(Http2StreamId stream_id) { + stream_id_ = stream_id; + result_ = Http2VisitorInterface::HEADER_OK; + } + + void set_frame_contains_fin() { frame_contains_fin_ = true; } + void set_header_type(HeaderType type) { type_ = type; } + HeaderType header_type() const { return type_; } + void OnHeaderBlockStart() override; void OnHeader(absl::string_view key, absl::string_view value) override; void OnHeaderBlockEnd(size_t /* uncompressed_header_bytes */, size_t /* compressed_header_bytes */) override; + absl::string_view status_header() { + QUICHE_DCHECK(type_ == HeaderType::RESPONSE || + type_ == HeaderType::RESPONSE_100); + return validator_.status_header(); + } private: + OgHttp2Session& session_; Http2VisitorInterface& visitor_; Http2StreamId stream_id_ = 0; + Http2VisitorInterface::OnHeaderResult result_ = + Http2VisitorInterface::HEADER_OK; + // Validates header blocks according to the HTTP/2 specification. + HeaderValidator validator_; + HeaderType type_ = HeaderType::RESPONSE; + bool frame_contains_fin_ = false; }; + // Queues the connection preface, if not already done. + void MaybeSetupPreface(); + + // Gets the settings to be sent in the initial SETTINGS frame sent as part of + // the connection preface. + std::vector GetInitialSettings() const; + + // Prepares and returns a SETTINGS frame with the given `settings`. + std::unique_ptr PrepareSettingsFrame( + absl::Span settings); + + void SendWindowUpdate(Http2StreamId stream_id, size_t update_delta); + + enum class SendResult { + // All data was flushed. + SEND_OK, + // Not all data was flushed (due to flow control or TCP back pressure). + SEND_BLOCKED, + // An error occurred while sending data. + SEND_ERROR, + }; + + // Sends the buffered connection preface or serialized frame data, if any. + SendResult MaybeSendBufferedData(); + + // Serializes and sends queued frames. + SendResult SendQueuedFrames(); + + void AfterFrameSent(uint8_t frame_type, uint32_t stream_id, + size_t payload_length, uint8_t flags, + uint32_t error_code); + + // Writes DATA frames for stream `stream_id`. + SendResult WriteForStream(Http2StreamId stream_id); + + SendResult SendMetadata(Http2StreamId stream_id, MetadataSequence& sequence); + + void SendHeaders(Http2StreamId stream_id, spdy::SpdyHeaderBlock headers, + bool end_stream); + + void SendTrailers(Http2StreamId stream_id, spdy::SpdyHeaderBlock trailers); + + // Encapsulates the RST_STREAM NO_ERROR behavior described in RFC 7540 + // Section 8.1. + void MaybeFinWithRstStream(StreamStateMap::iterator iter); + + // Performs flow control accounting for data sent by the peer. + void MarkDataBuffered(Http2StreamId stream_id, size_t bytes); + + // Creates a stream for `stream_id` if not already present and returns an + // iterator pointing to it. + StreamStateMap::iterator CreateStream(Http2StreamId stream_id); + + // Creates a stream for `stream_id`, stores the `data_source` and `user_data` + // in the stream state, and sends the `headers`. + void StartRequest(Http2StreamId stream_id, spdy::SpdyHeaderBlock headers, + std::unique_ptr data_source, + void* user_data); + + // Closes the given `stream_id` with the given `error_code`. + void CloseStream(Http2StreamId stream_id, Http2ErrorCode error_code); + + // Calculates the next expected header type for a stream in a given state. + HeaderType NextHeaderType(absl::optional current_type); + + // Returns true if the session can create a new stream. + bool CanCreateStream() const; + + // Informs the visitor of the connection `error` and stops processing on the + // connection. If server-side, also sends a GOAWAY with `error_code`. + void LatchErrorAndNotify(Http2ErrorCode error_code, + Http2VisitorInterface::ConnectionError error); + + void CloseStreamIfReady(uint8_t frame_type, uint32_t stream_id); + + // Receives events when inbound frames are parsed. Http2VisitorInterface& visitor_; + + // Forwards received events to the session if it can accept them. + EventForwarder event_forwarder_; + + // Logs received frames when enabled. + Http2TraceLogger receive_logger_; + // Logs sent frames when enabled. + Http2FrameLogger send_logger_; + + // Encodes outbound frames. spdy::SpdyFramer framer_{spdy::SpdyFramer::ENABLE_COMPRESSION}; + + // Decodes inbound frames. http2::Http2DecoderAdapter decoder_; - absl::flat_hash_map stream_map_; + + // Maintains the state of active streams known to this session. + StreamStateMap stream_map_; + + // Maintains the state of pending streams known to this session. A pending + // stream is kept in this list until it can be created while complying with + // `max_outbound_concurrent_streams_`. + std::list pending_streams_; + + // The queue of outbound frames. std::list> frames_; + // Buffered data (connection preface, serialized frames) that has not yet been + // sent. + std::string buffered_data_; + + // Maintains the set of streams ready to write data to the peer. + using WriteScheduler = PriorityWriteScheduler; + WriteScheduler write_scheduler_; + + // Stores the queue of callbacks to invoke upon receiving SETTINGS acks. At + // most one callback is invoked for each SETTINGS ack. + using SettingsAckCallback = std::function; + std::list settings_ack_callbacks_; + + // Delivers header name-value pairs to the visitor. PassthroughHeadersHandler headers_handler_; - std::string serialized_prefix_; + + // Ignores header data, e.g., for an unknown or rejected stream. + spdy::NoOpHeadersHandler noop_headers_handler_; + + // Tracks the remaining client connection preface, in the case of a server + // session. absl::string_view remaining_preface_; - int peer_window_ = 65535; + + WindowManager connection_window_manager_; + + absl::flat_hash_set streams_reset_; + absl::flat_hash_map queued_frames_; + + MetadataSequence connection_metadata_; + + Http2StreamId next_stream_id_ = 1; + // The highest received stream ID is the highest stream ID in any frame read + // from the peer. The highest processed stream ID is the highest stream ID for + // which this endpoint created a stream in the stream map. + Http2StreamId highest_received_stream_id_ = 0; + Http2StreamId highest_processed_stream_id_ = 0; + Http2StreamId metadata_stream_id_ = 0; + size_t metadata_length_ = 0; + int32_t connection_send_window_ = kInitialFlowControlWindowSize; + // The initial flow control receive window size for any newly created streams. + int32_t stream_receive_window_limit_ = kInitialFlowControlWindowSize; + uint32_t max_frame_payload_ = 16384u; + // The maximum number of concurrent streams that this connection can open to + // its peer and allow from its peer, respectively. Although the initial value + // is unlimited, the spec encourages a value of at least 100. We limit + // ourselves to opening 100 until told otherwise by the peer and allow an + // unlimited number from the peer until updated from SETTINGS we send. + uint32_t max_outbound_concurrent_streams_ = 100u; + uint32_t pending_max_inbound_concurrent_streams_ = + std::numeric_limits::max(); + uint32_t max_inbound_concurrent_streams_ = + std::numeric_limits::max(); Options options_; + + // The HPACK encoder header table capacity that will be applied when + // acking SETTINGS from the peer. Only contains a value if the peer advertises + // a larger table capacity than currently used; a smaller value can safely be + // applied immediately upon receipt. + absl::optional encoder_header_table_capacity_when_acking_; + bool received_goaway_ = false; + bool queued_preface_ = false; + bool peer_supports_metadata_ = false; + bool end_metadata_ = false; + + // Recursion guard for ProcessBytes(). + bool processing_bytes_ = false; + // Recursion guard for Send(). + bool sending_ = false; + + // Replace this with a stream ID, for multiple GOAWAY support. + bool queued_goaway_ = false; + bool latched_error_ = false; }; } // namespace adapter diff --git a/gquiche/http2/adapter/oghttp2_session_test.cc b/gquiche/http2/adapter/oghttp2_session_test.cc index 0302faed..7b7e6c7f 100644 --- a/gquiche/http2/adapter/oghttp2_session_test.cc +++ b/gquiche/http2/adapter/oghttp2_session_test.cc @@ -2,11 +2,30 @@ #include "gquiche/http2/adapter/mock_http2_visitor.h" #include "gquiche/http2/adapter/test_frame_sequence.h" +#include "gquiche/http2/adapter/test_utils.h" #include "gquiche/common/platform/api/quiche_test.h" namespace http2 { namespace adapter { namespace test { +namespace { + +using spdy::SpdyFrameType; +using testing::_; + +enum FrameType { + DATA, + HEADERS, + PRIORITY, + RST_STREAM, + SETTINGS, + PUSH_PROMISE, + PING, + GOAWAY, + WINDOW_UPDATE, +}; + +} // namespace TEST(OgHttp2SessionTest, ClientConstruction) { testing::StrictMock visitor; @@ -14,7 +33,9 @@ TEST(OgHttp2SessionTest, ClientConstruction) { visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); EXPECT_TRUE(session.want_read()); EXPECT_FALSE(session.want_write()); - EXPECT_EQ(session.GetRemoteWindowSize(), kDefaultInitialStreamWindowSize); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); + EXPECT_FALSE(session.IsServerSession()); + EXPECT_EQ(0, session.GetHighestReceivedStreamId()); } TEST(OgHttp2SessionTest, ClientHandlesFrames) { @@ -30,45 +51,450 @@ TEST(OgHttp2SessionTest, ClientHandlesFrames) { testing::InSequence s; // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); EXPECT_CALL(visitor, OnSettingsStart()); EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); - const ssize_t initial_result = session.ProcessBytes(initial_frames); - EXPECT_EQ(initial_frames.size(), initial_result); + const int64_t initial_result = session.ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); EXPECT_EQ(session.GetRemoteWindowSize(), - kDefaultInitialStreamWindowSize + 1000); + kInitialFlowControlWindowSize + 1000); + EXPECT_EQ(0, session.GetHighestReceivedStreamId()); + + // Connection has not yet received any data. + EXPECT_EQ(kInitialFlowControlWindowSize, session.GetReceiveWindowSize()); + + EXPECT_EQ(0, session.GetHpackDecoderDynamicTableSize()); - // Should OgHttp2Session require that streams 1 and 3 have been created? + // Submit a request to ensure the first stream is created. + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = absl::make_unique(visitor, true); + body1->AppendPayload("This is an example request body."); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_EQ(stream_id, 1); + + // Submit another request to ensure the next stream is created. + int stream_id2 = + session.SubmitRequest(ToHeaders({{":method", "GET"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + EXPECT_EQ(stream_id2, 3); const std::string stream_frames = TestFrameSequence() - .Headers(1, + .Headers(stream_id, {{":status", "200"}, {"server", "my-fake-server"}, {"date", "Tue, 6 Apr 2021 12:54:01 GMT"}}, /*fin=*/false) - .Data(1, "This is the response body.") - .RstStream(3, Http2ErrorCode::INTERNAL_ERROR) + .Data(stream_id, "This is the response body.") + .RstStream(stream_id2, Http2ErrorCode::INTERNAL_ERROR) .GoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "calm down!!") .Serialize(); - EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); - EXPECT_CALL(visitor, OnHeaderForStream(1, ":status", "200")); - EXPECT_CALL(visitor, OnHeaderForStream(1, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, _, HEADERS, 4)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, ":status", "200")); EXPECT_CALL(visitor, - OnHeaderForStream(1, "date", "Tue, 6 Apr 2021 12:54:01 GMT")); - EXPECT_CALL(visitor, OnEndHeadersForStream(1)); - EXPECT_CALL(visitor, OnBeginDataForStream(1, 26)); - EXPECT_CALL(visitor, OnDataForStream(1, "This is the response body.")); - EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::INTERNAL_ERROR)); - EXPECT_CALL(visitor, OnAbortStream(3, Http2ErrorCode::INTERNAL_ERROR)); + OnHeaderForStream(stream_id, "server", "my-fake-server")); + EXPECT_CALL(visitor, OnHeaderForStream(stream_id, "date", + "Tue, 6 Apr 2021 12:54:01 GMT")); + EXPECT_CALL(visitor, OnEndHeadersForStream(stream_id)); + EXPECT_CALL(visitor, OnFrameHeader(stream_id, 26, DATA, 0)); + EXPECT_CALL(visitor, OnBeginDataForStream(stream_id, 26)); + EXPECT_CALL(visitor, + OnDataForStream(stream_id, "This is the response body.")); + EXPECT_CALL(visitor, OnFrameHeader(stream_id2, 4, RST_STREAM, 0)); + EXPECT_CALL(visitor, OnRstStream(stream_id2, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, + OnCloseStream(stream_id2, Http2ErrorCode::INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnFrameHeader(0, 19, GOAWAY, 0)); EXPECT_CALL(visitor, OnGoAway(5, Http2ErrorCode::ENHANCE_YOUR_CALM, "")); - const ssize_t stream_result = session.ProcessBytes(stream_frames); - EXPECT_EQ(stream_frames.size(), stream_result); + const int64_t stream_result = session.ProcessBytes(stream_frames); + EXPECT_EQ(stream_frames.size(), static_cast(stream_result)); + EXPECT_EQ(stream_id2, session.GetHighestReceivedStreamId()); + + // The first stream is active and has received some data. + EXPECT_GT(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowSize(stream_id)); + // Connection receive window is equivalent to the first stream's. + EXPECT_EQ(session.GetReceiveWindowSize(), + session.GetStreamReceiveWindowSize(stream_id)); + // Receive window upper bound is still the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowLimit(stream_id)); + + EXPECT_GT(session.GetHpackDecoderDynamicTableSize(), 0); +} + +// Verifies that a client session enqueues initial SETTINGS if Send() is called +// before any frames are explicitly queued. +TEST(OgHttp2SessionTest, ClientEnqueuesSettingsOnSend) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + EXPECT_FALSE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); +} + +// Verifies that a client session enqueues initial SETTINGS before whatever +// frame type is passed to the first invocation of EnqueueFrame(). +TEST(OgHttp2SessionTest, ClientEnqueuesSettingsBeforeOtherFrame) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(absl::make_unique(42)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, 8, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, 8, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); +} + +// Verifies that if the first call to EnqueueFrame() passes a SETTINGS frame, +// the client session will not enqueue an additional SETTINGS frame. +TEST(OgHttp2SessionTest, ClientEnqueuesSettingsOnce) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(absl::make_unique()); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2SessionTest, ClientSubmitRequest) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + + EXPECT_FALSE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + // Even though the user has not queued any frames for the session, it should + // still send the connection preface. + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + // Initial SETTINGS. + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + const std::string initial_frames = + TestFrameSequence().ServerPreface().Serialize(); + testing::InSequence s; + + // Server preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + + const int64_t initial_result = session.ProcessBytes(initial_frames); + EXPECT_EQ(initial_frames.size(), static_cast(initial_result)); + + // Session will want to write a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_EQ(0, session.GetHpackEncoderDynamicTableSize()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = absl::make_unique(visitor, true); + body1->AppendPayload("This is an example request body."); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS, + spdy::SpdyFrameType::DATA})); + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(session.GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); + EXPECT_GT(session.GetStreamSendWindowSize(stream_id), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(-1, session.GetStreamSendWindowSize(stream_id + 2)); + + EXPECT_GT(session.GetHpackEncoderDynamicTableSize(), 0); + + stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/two"}}), + nullptr, nullptr); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + const char* kSentinel2 = "arbitrary pointer 2"; + EXPECT_EQ(nullptr, session.GetStreamUserData(stream_id)); + session.SetStreamUserData(stream_id, const_cast(kSentinel2)); + EXPECT_EQ(kSentinel2, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x5, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({spdy::SpdyFrameType::HEADERS})); + + // No data was sent (just HEADERS), so the remaining send window size should + // still be the default. + EXPECT_EQ(session.GetStreamSendWindowSize(stream_id), + kInitialFlowControlWindowSize); +} + +// This test exercises the case where the client request body source is read +// blocked. +TEST(OgHttp2SessionTest, ClientSubmitRequestWithReadBlock) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + EXPECT_FALSE(session.want_write()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = absl::make_unique(visitor, true); + TestDataFrameSource* body_ref = body1.get(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + // No data frame, as body1 was read blocked. + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + body_ref->AppendPayload("This is an example request body."); + body_ref->EndData(); + EXPECT_TRUE(session.ResumeStream(stream_id)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(session.ResumeStream(stream_id)); + EXPECT_FALSE(session.want_write()); +} + +// This test exercises the case where the client request body source is read +// blocked, then ends with an empty DATA frame. +TEST(OgHttp2SessionTest, ClientSubmitRequestEmptyDataWithFin) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + EXPECT_FALSE(session.want_write()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = absl::make_unique(visitor, true); + TestDataFrameSource* body_ref = body1.get(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS})); + // No data frame, as body1 was read blocked. + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + body_ref->EndData(); + EXPECT_TRUE(session.ResumeStream(stream_id)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, 0, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); + + // Stream data is done, so this stream cannot be resumed. + EXPECT_FALSE(session.ResumeStream(stream_id)); + EXPECT_FALSE(session.want_write()); +} + +// This test exercises the case where the connection to the peer is write +// blocked. +TEST(OgHttp2SessionTest, ClientSubmitRequestWithWriteBlock) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + EXPECT_FALSE(session.want_write()); + + const char* kSentinel1 = "arbitrary pointer 1"; + auto body1 = absl::make_unique(visitor, true); + body1->AppendPayload("This is an example request body."); + body1->EndData(); + int stream_id = + session.SubmitRequest(ToHeaders({{":method", "POST"}, + {":scheme", "http"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}), + std::move(body1), const_cast(kSentinel1)); + EXPECT_GT(stream_id, 0); + EXPECT_TRUE(session.want_write()); + EXPECT_EQ(kSentinel1, session.GetStreamUserData(stream_id)); + visitor.set_is_write_blocked(true); + int result = session.Send(); + EXPECT_EQ(0, result); + + EXPECT_THAT(visitor.data(), testing::IsEmpty()); + EXPECT_TRUE(session.want_write()); + visitor.set_is_write_blocked(false); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, stream_id, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, stream_id, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, stream_id, _, 0x1, 0)); + + result = session.Send(); + EXPECT_EQ(0, result); + + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::HEADERS, + SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); +} + +TEST(OgHttp2SessionTest, ClientStartShutdown) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kClient}); + + EXPECT_FALSE(session.want_write()); + + // No-op (except for logging) for a client implementation. + session.StartGracefulShutdown(); + EXPECT_FALSE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + + absl::string_view serialized = visitor.data(); + EXPECT_THAT(serialized, + testing::StartsWith(spdy::kHttp2ConnectionHeaderPrefix)); + serialized.remove_prefix(strlen(spdy::kHttp2ConnectionHeaderPrefix)); + EXPECT_THAT(serialized, EqualsFrames({SpdyFrameType::SETTINGS})); } TEST(OgHttp2SessionTest, ServerConstruction) { @@ -77,14 +503,18 @@ TEST(OgHttp2SessionTest, ServerConstruction) { visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); EXPECT_TRUE(session.want_read()); EXPECT_FALSE(session.want_write()); - EXPECT_EQ(session.GetRemoteWindowSize(), kDefaultInitialStreamWindowSize); + EXPECT_EQ(session.GetRemoteWindowSize(), kInitialFlowControlWindowSize); + EXPECT_TRUE(session.IsServerSession()); + EXPECT_EQ(0, session.GetHighestReceivedStreamId()); } TEST(OgHttp2SessionTest, ServerHandlesFrames) { - testing::StrictMock visitor; + DataSavingVisitor visitor; OgHttp2Session session( visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + EXPECT_EQ(0, session.GetHpackDecoderDynamicTableSize()); + const std::string frames = TestFrameSequence() .ClientPreface() .Ping(42) @@ -108,21 +538,34 @@ TEST(OgHttp2SessionTest, ServerHandlesFrames) { .Serialize(); testing::InSequence s; + const char* kSentinel1 = "arbitrary pointer 1"; + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); EXPECT_CALL(visitor, OnSettingsStart()); EXPECT_CALL(visitor, OnSettingsEnd()); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); EXPECT_CALL(visitor, OnPing(42, false)); + EXPECT_CALL(visitor, OnFrameHeader(0, 4, WINDOW_UPDATE, 0)); EXPECT_CALL(visitor, OnWindowUpdate(0, 1000)); + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 4)); EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "POST")); EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); - EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&session, kSentinel1]() { + session.SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnFrameHeader(1, 4, WINDOW_UPDATE, 0)); EXPECT_CALL(visitor, OnWindowUpdate(1, 2000)); + EXPECT_CALL(visitor, OnFrameHeader(1, 25, DATA, 0)); EXPECT_CALL(visitor, OnBeginDataForStream(1, 25)); EXPECT_CALL(visitor, OnDataForStream(1, "This is the request body.")); + EXPECT_CALL(visitor, OnFrameHeader(3, _, HEADERS, 5)); EXPECT_CALL(visitor, OnBeginHeadersForStream(3)); EXPECT_CALL(visitor, OnHeaderForStream(3, ":method", "GET")); EXPECT_CALL(visitor, OnHeaderForStream(3, ":scheme", "http")); @@ -130,15 +573,418 @@ TEST(OgHttp2SessionTest, ServerHandlesFrames) { EXPECT_CALL(visitor, OnHeaderForStream(3, ":path", "/this/is/request/two")); EXPECT_CALL(visitor, OnEndHeadersForStream(3)); EXPECT_CALL(visitor, OnEndStream(3)); + EXPECT_CALL(visitor, OnFrameHeader(3, 4, RST_STREAM, 0)); EXPECT_CALL(visitor, OnRstStream(3, Http2ErrorCode::CANCEL)); - EXPECT_CALL(visitor, OnAbortStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnCloseStream(3, Http2ErrorCode::CANCEL)); + EXPECT_CALL(visitor, OnFrameHeader(0, 8, PING, 0)); EXPECT_CALL(visitor, OnPing(47, false)); - const ssize_t result = session.ProcessBytes(frames); - EXPECT_EQ(frames.size(), result); + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_EQ(kSentinel1, session.GetStreamUserData(1)); + + // The first stream is active and has received some data. + EXPECT_GT(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowSize(1)); + // Connection receive window is equivalent to the first stream's. + EXPECT_EQ(session.GetReceiveWindowSize(), + session.GetStreamReceiveWindowSize(1)); + // Receive window upper bound is still the initial value. + EXPECT_EQ(kInitialFlowControlWindowSize, + session.GetStreamReceiveWindowLimit(1)); + + EXPECT_GT(session.GetHpackDecoderDynamicTableSize(), 0); + + // It should no longer be possible to set user data on a closed stream. + const char* kSentinel3 = "another arbitrary pointer"; + session.SetStreamUserData(3, const_cast(kSentinel3)); + EXPECT_EQ(nullptr, session.GetStreamUserData(3)); EXPECT_EQ(session.GetRemoteWindowSize(), - kDefaultInitialStreamWindowSize + 1000); + kInitialFlowControlWindowSize + 1000); + EXPECT_EQ(3, session.GetHighestReceivedStreamId()); + + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x1, 0)); + + // Some bytes should have been serialized. + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + // Initial SETTINGS, SETTINGS ack, and PING acks (for PING IDs 42 and 47). + EXPECT_THAT(visitor.data(), + EqualsFrames( + {spdy::SpdyFrameType::SETTINGS, spdy::SpdyFrameType::SETTINGS, + spdy::SpdyFrameType::PING, spdy::SpdyFrameType::PING})); +} + +// Verifies that a server session enqueues initial SETTINGS before whatever +// frame type is passed to the first invocation of EnqueueFrame(). +TEST(OgHttp2SessionTest, ServerEnqueuesSettingsBeforeOtherFrame) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(absl::make_unique(42)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(PING, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(PING, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::PING})); +} + +// Verifies that if the first call to EnqueueFrame() passes a SETTINGS frame, +// the server session will not enqueue an additional SETTINGS frame. +TEST(OgHttp2SessionTest, ServerEnqueuesSettingsOnce) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + EXPECT_FALSE(session.want_write()); + session.EnqueueFrame(absl::make_unique()); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::SETTINGS})); +} + +TEST(OgHttp2SessionTest, ServerSubmitResponse) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + + EXPECT_FALSE(session.want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + const char* kSentinel1 = "arbitrary pointer 1"; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)) + .WillOnce(testing::InvokeWithoutArgs([&session, kSentinel1]() { + session.SetStreamUserData(1, const_cast(kSentinel1)); + return true; + })); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + EXPECT_EQ(1, session.GetHighestReceivedStreamId()); + + EXPECT_EQ(0, session.GetHpackEncoderDynamicTableSize()); + + // Server will want to send initial SETTINGS, and a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); + // A data fin is not sent so that the stream remains open, and the flow + // control state can be verified. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload("This is an example response body."); + int submit_result = session.SubmitResponse( + 1, + ToHeaders({{":status", "404"}, + {"x-comment", "I have no idea what you're talking about."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(session.want_write()); + + // Stream user data should have been set successfully after receiving headers. + EXPECT_EQ(kSentinel1, session.GetStreamUserData(1)); + session.SetStreamUserData(1, nullptr); + EXPECT_EQ(nullptr, session.GetStreamUserData(1)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + EXPECT_FALSE(session.want_write()); + + // Some data was sent, so the remaining send window size should be less than + // the default. + EXPECT_LT(session.GetStreamSendWindowSize(1), kInitialFlowControlWindowSize); + EXPECT_GT(session.GetStreamSendWindowSize(1), 0); + // Send window for a nonexistent stream is not available. + EXPECT_EQ(session.GetStreamSendWindowSize(3), -1); + + EXPECT_GT(session.GetHpackEncoderDynamicTableSize(), 0); +} + +TEST(OgHttp2SessionTest, ServerStartShutdown) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + + EXPECT_FALSE(session.want_write()); + + session.StartGracefulShutdown(); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); +} + +TEST(OgHttp2SessionTest, ServerStartShutdownAfterGoaway) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + + EXPECT_FALSE(session.want_write()); + + auto goaway = absl::make_unique( + 1, spdy::ERROR_CODE_NO_ERROR, "and don't come back!"); + session.EnqueueFrame(std::move(goaway)); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(GOAWAY, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(GOAWAY, 0, _, 0x0, 0)); + + int result = session.Send(); + EXPECT_EQ(0, result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::GOAWAY})); + + // No-op, since a GOAWAY has previously been enqueued. + session.StartGracefulShutdown(); + EXPECT_FALSE(session.want_write()); +} + +// Tests the case where the server queues trailers after the data stream is +// exhausted. +TEST(OgHttp2SessionTest, ServerSendsTrailers) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + + EXPECT_FALSE(session.want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + // Server will want to send initial SETTINGS, and a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload("This is an example response body."); + body1->EndData(); + int submit_result = session.SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA})); + visitor.Clear(); + EXPECT_FALSE(session.want_write()); + + // The body source has been exhausted by the call to Send() above. + int trailer_result = session.SubmitTrailer( + 1, ToHeaders({{"final-status", "a-ok"}, + {"x-comment", "trailers sure are cool"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), EqualsFrames({SpdyFrameType::HEADERS})); +} + +// Tests the case where the server queues trailers immediately after headers and +// data, and before any writes have taken place. +TEST(OgHttp2SessionTest, ServerQueuesTrailersWithResponse) { + DataSavingVisitor visitor; + OgHttp2Session session( + visitor, OgHttp2Session::Options{.perspective = Perspective::kServer}); + + EXPECT_FALSE(session.want_write()); + + const std::string frames = TestFrameSequence() + .ClientPreface() + .Headers(1, + {{":method", "GET"}, + {":scheme", "https"}, + {":authority", "example.com"}, + {":path", "/this/is/request/one"}}, + /*fin=*/true) + .Serialize(); + testing::InSequence s; + + // Client preface (empty SETTINGS) + EXPECT_CALL(visitor, OnFrameHeader(0, 0, SETTINGS, 0)); + EXPECT_CALL(visitor, OnSettingsStart()); + EXPECT_CALL(visitor, OnSettingsEnd()); + // Stream 1 + EXPECT_CALL(visitor, OnFrameHeader(1, _, HEADERS, 5)); + EXPECT_CALL(visitor, OnBeginHeadersForStream(1)); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":method", "GET")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":scheme", "https")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":authority", "example.com")); + EXPECT_CALL(visitor, OnHeaderForStream(1, ":path", "/this/is/request/one")); + EXPECT_CALL(visitor, OnEndHeadersForStream(1)); + EXPECT_CALL(visitor, OnEndStream(1)); + + const int64_t result = session.ProcessBytes(frames); + EXPECT_EQ(frames.size(), static_cast(result)); + + // Server will want to send initial SETTINGS, and a SETTINGS ack. + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x0)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x0, 0)); + EXPECT_CALL(visitor, OnBeforeFrameSent(SETTINGS, 0, _, 0x1)); + EXPECT_CALL(visitor, OnFrameSent(SETTINGS, 0, _, 0x1, 0)); + + int send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::SETTINGS, SpdyFrameType::SETTINGS})); + visitor.Clear(); + + EXPECT_FALSE(session.want_write()); + + // The body source must indicate that the end of the body is not the end of + // the stream. + auto body1 = absl::make_unique(visitor, false); + body1->AppendPayload("This is an example response body."); + body1->EndData(); + int submit_result = session.SubmitResponse( + 1, ToHeaders({{":status", "200"}, {"x-comment", "Sure, sounds good."}}), + std::move(body1)); + EXPECT_EQ(submit_result, 0); + EXPECT_TRUE(session.want_write()); + // There has not been a call to Send() yet, so neither headers nor body have + // been written. + int trailer_result = session.SubmitTrailer( + 1, ToHeaders({{"final-status", "a-ok"}, + {"x-comment", "trailers sure are cool"}})); + ASSERT_EQ(trailer_result, 0); + EXPECT_TRUE(session.want_write()); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x4)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x4, 0)); + EXPECT_CALL(visitor, OnFrameSent(DATA, 1, _, 0x0, 0)); + + EXPECT_CALL(visitor, OnBeforeFrameSent(HEADERS, 1, _, 0x5)); + EXPECT_CALL(visitor, OnFrameSent(HEADERS, 1, _, 0x5, 0)); + EXPECT_CALL(visitor, OnCloseStream(1, Http2ErrorCode::HTTP2_NO_ERROR)); + + send_result = session.Send(); + EXPECT_EQ(0, send_result); + EXPECT_THAT(visitor.data(), + EqualsFrames({SpdyFrameType::HEADERS, SpdyFrameType::DATA, + SpdyFrameType::HEADERS})); } } // namespace test diff --git a/gquiche/http2/adapter/oghttp2_util.cc b/gquiche/http2/adapter/oghttp2_util.cc new file mode 100644 index 00000000..9cfef194 --- /dev/null +++ b/gquiche/http2/adapter/oghttp2_util.cc @@ -0,0 +1,17 @@ +#include "gquiche/http2/adapter/oghttp2_util.h" + +namespace http2 { +namespace adapter { + +spdy::SpdyHeaderBlock ToHeaderBlock(absl::Span headers) { + spdy::SpdyHeaderBlock block; + for (const Header& header : headers) { + absl::string_view name = GetStringView(header.first).first; + absl::string_view value = GetStringView(header.second).first; + block[name] = value; + } + return block; +} + +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/oghttp2_util.h b/gquiche/http2/adapter/oghttp2_util.h new file mode 100644 index 00000000..754c38ec --- /dev/null +++ b/gquiche/http2/adapter/oghttp2_util.h @@ -0,0 +1,18 @@ +#ifndef QUICHE_HTTP2_ADAPTER_OGHTTP2_UTIL_H_ +#define QUICHE_HTTP2_ADAPTER_OGHTTP2_UTIL_H_ + +#include "absl/types/span.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/spdy/core/spdy_header_block.h" + +namespace http2 { +namespace adapter { + +QUICHE_EXPORT_PRIVATE spdy::SpdyHeaderBlock ToHeaderBlock( + absl::Span headers); + +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_OGHTTP2_UTIL_H_ diff --git a/gquiche/http2/adapter/recording_http2_visitor.cc b/gquiche/http2/adapter/recording_http2_visitor.cc new file mode 100644 index 00000000..2adbc893 --- /dev/null +++ b/gquiche/http2/adapter/recording_http2_visitor.cc @@ -0,0 +1,174 @@ +#include "gquiche/http2/adapter/recording_http2_visitor.h" + +#include "absl/strings/str_format.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/http2_util.h" + +namespace http2 { +namespace adapter { +namespace test { + +int64_t RecordingHttp2Visitor::OnReadyToSend(absl::string_view serialized) { + events_.push_back(absl::StrFormat("OnReadyToSend %d", serialized.size())); + return serialized.size(); +} + +void RecordingHttp2Visitor::OnConnectionError(ConnectionError error) { + events_.push_back( + absl::StrFormat("OnConnectionError %s", ConnectionErrorToString(error))); +} + +bool RecordingHttp2Visitor::OnFrameHeader(Http2StreamId stream_id, + size_t length, uint8_t type, + uint8_t flags) { + events_.push_back(absl::StrFormat("OnFrameHeader %d %d %d %d", stream_id, + length, type, flags)); + return true; +} + +void RecordingHttp2Visitor::OnSettingsStart() { + events_.push_back("OnSettingsStart"); +} + +void RecordingHttp2Visitor::OnSetting(Http2Setting setting) { + events_.push_back(absl::StrFormat( + "OnSetting %s %d", Http2SettingsIdToString(setting.id), setting.value)); +} + +void RecordingHttp2Visitor::OnSettingsEnd() { + events_.push_back("OnSettingsEnd"); +} + +void RecordingHttp2Visitor::OnSettingsAck() { + events_.push_back("OnSettingsAck"); +} + +bool RecordingHttp2Visitor::OnBeginHeadersForStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnBeginHeadersForStream %d", stream_id)); + return true; +} + +Http2VisitorInterface::OnHeaderResult RecordingHttp2Visitor::OnHeaderForStream( + Http2StreamId stream_id, absl::string_view name, absl::string_view value) { + events_.push_back( + absl::StrFormat("OnHeaderForStream %d %s %s", stream_id, name, value)); + return HEADER_OK; +} + +bool RecordingHttp2Visitor::OnEndHeadersForStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnEndHeadersForStream %d", stream_id)); + return true; +} + +bool RecordingHttp2Visitor::OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) { + events_.push_back( + absl::StrFormat("OnBeginDataForStream %d %d", stream_id, payload_length)); + return true; +} + +bool RecordingHttp2Visitor::OnDataForStream(Http2StreamId stream_id, + absl::string_view data) { + events_.push_back(absl::StrFormat("OnDataForStream %d %s", stream_id, data)); + return true; +} + +void RecordingHttp2Visitor::OnEndStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnEndStream %d", stream_id)); +} + +void RecordingHttp2Visitor::OnRstStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + events_.push_back(absl::StrFormat("OnRstStream %d %s", stream_id, + Http2ErrorCodeToString(error_code))); +} + +void RecordingHttp2Visitor::OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) { + events_.push_back(absl::StrFormat("OnCloseStream %d %s", stream_id, + Http2ErrorCodeToString(error_code))); +} + +void RecordingHttp2Visitor::OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, + bool exclusive) { + events_.push_back(absl::StrFormat("OnPriorityForStream %d %d %d %d", + stream_id, parent_stream_id, weight, + exclusive)); +} + +void RecordingHttp2Visitor::OnPing(Http2PingId ping_id, bool is_ack) { + events_.push_back(absl::StrFormat("OnPing %d %d", ping_id, is_ack)); +} + +void RecordingHttp2Visitor::OnPushPromiseForStream( + Http2StreamId stream_id, + Http2StreamId promised_stream_id) { + events_.push_back(absl::StrFormat("OnPushPromiseForStream %d %d", stream_id, + promised_stream_id)); +} + +bool RecordingHttp2Visitor::OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) { + events_.push_back( + absl::StrFormat("OnGoAway %d %s %s", last_accepted_stream_id, + Http2ErrorCodeToString(error_code), opaque_data)); + return true; +} + +void RecordingHttp2Visitor::OnWindowUpdate(Http2StreamId stream_id, + int window_increment) { + events_.push_back( + absl::StrFormat("OnWindowUpdate %d %d", stream_id, window_increment)); +} + +int RecordingHttp2Visitor::OnBeforeFrameSent(uint8_t frame_type, + Http2StreamId stream_id, + size_t length, uint8_t flags) { + events_.push_back(absl::StrFormat("OnBeforeFrameSent %d %d %d %d", frame_type, + stream_id, length, flags)); + return 0; +} + +int RecordingHttp2Visitor::OnFrameSent(uint8_t frame_type, + Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code) { + events_.push_back(absl::StrFormat("OnFrameSent %d %d %d %d %d", frame_type, + stream_id, length, flags, error_code)); + return 0; +} + +bool RecordingHttp2Visitor::OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) { + events_.push_back(absl::StrFormat("OnInvalidFrame %d %s", stream_id, + InvalidFrameErrorToString(error))); + return true; +} + +void RecordingHttp2Visitor::OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) { + events_.push_back(absl::StrFormat("OnBeginMetadataForStream %d %d", stream_id, + payload_length)); +} + +bool RecordingHttp2Visitor::OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) { + events_.push_back( + absl::StrFormat("OnMetadataForStream %d %s", stream_id, metadata)); + return true; +} + +bool RecordingHttp2Visitor::OnMetadataEndForStream(Http2StreamId stream_id) { + events_.push_back(absl::StrFormat("OnMetadataEndForStream %d", stream_id)); + return true; +} + +void RecordingHttp2Visitor::OnErrorDebug(absl::string_view message) { + events_.push_back(absl::StrFormat("OnErrorDebug %s", message)); +} + +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/recording_http2_visitor.h b/gquiche/http2/adapter/recording_http2_visitor.h new file mode 100644 index 00000000..a53447fe --- /dev/null +++ b/gquiche/http2/adapter/recording_http2_visitor.h @@ -0,0 +1,79 @@ +#ifndef QUICHE_HTTP2_ADAPTER_RECORDING_HTTP2_VISITOR_H_ +#define QUICHE_HTTP2_ADAPTER_RECORDING_HTTP2_VISITOR_H_ + +#include +#include +#include + +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { + +// A visitor implementation that records the sequence of callbacks it receives. +class QUICHE_NO_EXPORT RecordingHttp2Visitor : public Http2VisitorInterface { + public: + using Event = std::string; + using EventSequence = std::list; + + // From Http2VisitorInterface + int64_t OnReadyToSend(absl::string_view serialized) override; + void OnConnectionError(ConnectionError error) override; + bool OnFrameHeader(Http2StreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + void OnSettingsStart() override; + void OnSetting(Http2Setting setting) override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + bool OnBeginHeadersForStream(Http2StreamId stream_id) override; + OnHeaderResult OnHeaderForStream(Http2StreamId stream_id, + absl::string_view name, + absl::string_view value) override; + bool OnEndHeadersForStream(Http2StreamId stream_id) override; + bool OnBeginDataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnDataForStream(Http2StreamId stream_id, + absl::string_view data) override; + void OnEndStream(Http2StreamId stream_id) override; + void OnRstStream(Http2StreamId stream_id, Http2ErrorCode error_code) override; + void OnCloseStream(Http2StreamId stream_id, + Http2ErrorCode error_code) override; + void OnPriorityForStream(Http2StreamId stream_id, + Http2StreamId parent_stream_id, + int weight, + bool exclusive) override; + void OnPing(Http2PingId ping_id, bool is_ack) override; + void OnPushPromiseForStream(Http2StreamId stream_id, + Http2StreamId promised_stream_id) override; + bool OnGoAway(Http2StreamId last_accepted_stream_id, + Http2ErrorCode error_code, + absl::string_view opaque_data) override; + void OnWindowUpdate(Http2StreamId stream_id, int window_increment) override; + int OnBeforeFrameSent(uint8_t frame_type, Http2StreamId stream_id, + size_t length, uint8_t flags) override; + int OnFrameSent(uint8_t frame_type, Http2StreamId stream_id, size_t length, + uint8_t flags, uint32_t error_code) override; + bool OnInvalidFrame(Http2StreamId stream_id, + InvalidFrameError error) override; + void OnBeginMetadataForStream(Http2StreamId stream_id, + size_t payload_length) override; + bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) override; + bool OnMetadataEndForStream(Http2StreamId stream_id) override; + void OnErrorDebug(absl::string_view message) override; + + const EventSequence& GetEventSequence() const { return events_; } + void Clear() { events_.clear(); } + + private: + EventSequence events_; +}; + +} // namespace test +} // namespace adapter +} // namespace http2 + +#endif // QUICHE_HTTP2_ADAPTER_RECORDING_HTTP2_VISITOR_H_ diff --git a/gquiche/http2/adapter/recording_http2_visitor_test.cc b/gquiche/http2/adapter/recording_http2_visitor_test.cc new file mode 100644 index 00000000..d09724d5 --- /dev/null +++ b/gquiche/http2/adapter/recording_http2_visitor_test.cc @@ -0,0 +1,132 @@ +#include "gquiche/http2/adapter/recording_http2_visitor.h" + +#include + +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/http2/test_tools/http2_random.h" +#include "gquiche/common/platform/api/quiche_test.h" + +namespace http2 { +namespace adapter { +namespace test { +namespace { + +using ::testing::IsEmpty; + +TEST(RecordingHttp2VisitorTest, EmptySequence) { + RecordingHttp2Visitor chocolate_visitor; + RecordingHttp2Visitor vanilla_visitor; + + EXPECT_THAT(chocolate_visitor.GetEventSequence(), IsEmpty()); + EXPECT_THAT(vanilla_visitor.GetEventSequence(), IsEmpty()); + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + chocolate_visitor.OnSettingsStart(); + + EXPECT_THAT(chocolate_visitor.GetEventSequence(), testing::Not(IsEmpty())); + EXPECT_THAT(vanilla_visitor.GetEventSequence(), IsEmpty()); + EXPECT_NE(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + chocolate_visitor.Clear(); + + EXPECT_THAT(chocolate_visitor.GetEventSequence(), IsEmpty()); + EXPECT_THAT(vanilla_visitor.GetEventSequence(), IsEmpty()); + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); +} + +TEST(RecordingHttp2VisitorTest, SameEventsProduceSameSequence) { + RecordingHttp2Visitor chocolate_visitor; + RecordingHttp2Visitor vanilla_visitor; + + // Prepare some random values to deliver with the events. + http2::test::Http2Random random; + const Http2StreamId stream_id = random.Uniform(kMaxStreamId); + const Http2StreamId another_stream_id = random.Uniform(kMaxStreamId); + const size_t length = random.Rand16(); + const uint8_t type = random.Rand8(); + const uint8_t flags = random.Rand8(); + const Http2ErrorCode error_code = static_cast( + random.Uniform(static_cast(Http2ErrorCode::MAX_ERROR_CODE))); + const Http2Setting setting = {.id = random.Rand16(), + .value = random.Rand32()}; + const absl::string_view alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-"; + const std::string some_string = + random.RandStringWithAlphabet(random.Rand8(), alphabet); + const std::string another_string = + random.RandStringWithAlphabet(random.Rand8(), alphabet); + const uint16_t some_int = random.Rand16(); + const bool some_bool = random.OneIn(2); + + // Send the same arbitrary sequence of events to both visitors. + std::list visitors = {&chocolate_visitor, + &vanilla_visitor}; + for (RecordingHttp2Visitor* visitor : visitors) { + visitor->OnConnectionError( + Http2VisitorInterface::ConnectionError::kSendError); + visitor->OnFrameHeader(stream_id, length, type, flags); + visitor->OnSettingsStart(); + visitor->OnSetting(setting); + visitor->OnSettingsEnd(); + visitor->OnSettingsAck(); + visitor->OnBeginHeadersForStream(stream_id); + visitor->OnHeaderForStream(stream_id, some_string, another_string); + visitor->OnEndHeadersForStream(stream_id); + visitor->OnBeginDataForStream(stream_id, length); + visitor->OnDataForStream(stream_id, some_string); + visitor->OnDataForStream(stream_id, another_string); + visitor->OnEndStream(stream_id); + visitor->OnRstStream(stream_id, error_code); + visitor->OnCloseStream(stream_id, error_code); + visitor->OnPriorityForStream(stream_id, another_stream_id, some_int, + some_bool); + visitor->OnPing(some_int, some_bool); + visitor->OnPushPromiseForStream(stream_id, another_stream_id); + visitor->OnGoAway(stream_id, error_code, some_string); + visitor->OnWindowUpdate(stream_id, some_int); + visitor->OnBeginMetadataForStream(stream_id, length); + visitor->OnMetadataForStream(stream_id, some_string); + visitor->OnMetadataForStream(stream_id, another_string); + visitor->OnMetadataEndForStream(stream_id); + } + + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); +} + +TEST(RecordingHttp2VisitorTest, DifferentEventsProduceDifferentSequence) { + RecordingHttp2Visitor chocolate_visitor; + RecordingHttp2Visitor vanilla_visitor; + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + const Http2StreamId stream_id = 1; + const size_t length = 42; + + // Different events with the same method arguments should produce different + // event sequences. + chocolate_visitor.OnBeginDataForStream(stream_id, length); + vanilla_visitor.OnBeginMetadataForStream(stream_id, length); + EXPECT_NE(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + chocolate_visitor.Clear(); + vanilla_visitor.Clear(); + EXPECT_EQ(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); + + // The same events with different method arguments should produce different + // event sequences. + chocolate_visitor.OnBeginHeadersForStream(stream_id); + vanilla_visitor.OnBeginHeadersForStream(stream_id + 2); + EXPECT_NE(chocolate_visitor.GetEventSequence(), + vanilla_visitor.GetEventSequence()); +} + +} // namespace +} // namespace test +} // namespace adapter +} // namespace http2 diff --git a/gquiche/http2/adapter/test_frame_sequence.cc b/gquiche/http2/adapter/test_frame_sequence.cc index 53467f98..49a0fa73 100644 --- a/gquiche/http2/adapter/test_frame_sequence.cc +++ b/gquiche/http2/adapter/test_frame_sequence.cc @@ -1,21 +1,33 @@ #include "gquiche/http2/adapter/test_frame_sequence.h" #include "gquiche/http2/adapter/http2_util.h" +#include "gquiche/http2/adapter/oghttp2_util.h" +#include "gquiche/spdy/core/hpack/hpack_encoder.h" #include "gquiche/spdy/core/spdy_framer.h" namespace http2 { namespace adapter { namespace test { -TestFrameSequence& TestFrameSequence::ClientPreface() { +std::vector ToHeaders( + absl::Span> headers) { + std::vector out; + for (const auto& header : headers) { + out.push_back( + std::make_pair(HeaderRep(header.first), HeaderRep(header.second))); + } + return out; +} + +TestFrameSequence& TestFrameSequence::ClientPreface( + absl::Span settings) { preface_ = spdy::kHttp2ConnectionHeaderPrefix; - frames_.push_back(absl::make_unique()); - return *this; + return Settings(settings); } -TestFrameSequence& TestFrameSequence::ServerPreface() { - frames_.push_back(absl::make_unique()); - return *this; +TestFrameSequence& TestFrameSequence::ServerPreface( + absl::Span settings) { + return Settings(settings); } TestFrameSequence& TestFrameSequence::Data(Http2StreamId stream_id, @@ -39,12 +51,12 @@ TestFrameSequence& TestFrameSequence::RstStream(Http2StreamId stream_id, } TestFrameSequence& TestFrameSequence::Settings( - absl::Span values) { - auto settings = absl::make_unique(); - for (const Http2Setting& setting : values) { - settings->AddSetting(setting.id, setting.value); + absl::Span settings) { + auto settings_frame = absl::make_unique(); + for (const Http2Setting& setting : settings) { + settings_frame->AddSetting(setting.id, setting.value); } - frames_.push_back(std::move(settings)); + frames_.push_back(std::move(settings_frame)); return *this; } @@ -55,6 +67,14 @@ TestFrameSequence& TestFrameSequence::SettingsAck() { return *this; } +TestFrameSequence& TestFrameSequence::PushPromise( + Http2StreamId stream_id, Http2StreamId promised_stream_id, + absl::Span headers) { + frames_.push_back(absl::make_unique( + stream_id, promised_stream_id, ToHeaderBlock(headers))); + return *this; +} + TestFrameSequence& TestFrameSequence::Ping(Http2PingId id) { frames_.push_back(absl::make_unique(id)); return *this; @@ -75,24 +95,47 @@ TestFrameSequence& TestFrameSequence::GoAway(Http2StreamId last_good_stream_id, return *this; } +TestFrameSequence& TestFrameSequence::Headers( + Http2StreamId stream_id, + absl::Span> headers, + bool fin, bool add_continuation) { + return Headers(stream_id, ToHeaders(headers), fin, add_continuation); +} + TestFrameSequence& TestFrameSequence::Headers(Http2StreamId stream_id, spdy::Http2HeaderBlock block, - bool fin) { - auto headers = - absl::make_unique(stream_id, std::move(block)); - headers->set_fin(fin); - frames_.push_back(std::move(headers)); + bool fin, bool add_continuation) { + if (add_continuation) { + // The normal intermediate representations don't allow you to represent a + // nonterminal HEADERS frame explicitly, so we'll need to use + // SpdyUnknownIRs. For simplicity, and in order not to mess up HPACK state, + // the payload will be uncompressed. + spdy::HpackEncoder encoder; + encoder.DisableCompression(); + std::string encoded_block = encoder.EncodeHeaderBlock(block); + const size_t pos = encoded_block.size() / 2; + const uint8_t flags = fin ? 0x1 : 0x0; + frames_.push_back(absl::make_unique( + stream_id, static_cast(spdy::SpdyFrameType::HEADERS), flags, + encoded_block.substr(0, pos))); + + auto continuation = absl::make_unique(stream_id); + continuation->set_end_headers(true); + continuation->take_encoding(encoded_block.substr(pos)); + frames_.push_back(std::move(continuation)); + } else { + auto headers = + absl::make_unique(stream_id, std::move(block)); + headers->set_fin(fin); + frames_.push_back(std::move(headers)); + } return *this; } TestFrameSequence& TestFrameSequence::Headers(Http2StreamId stream_id, absl::Span headers, - bool fin) { - spdy::SpdyHeaderBlock block; - for (const Header& header : headers) { - block[header.first] = header.second; - } - return Headers(stream_id, std::move(block), fin); + bool fin, bool add_continuation) { + return Headers(stream_id, ToHeaderBlock(headers), fin, add_continuation); } TestFrameSequence& TestFrameSequence::WindowUpdate(Http2StreamId stream_id, @@ -111,6 +154,25 @@ TestFrameSequence& TestFrameSequence::Priority(Http2StreamId stream_id, return *this; } +TestFrameSequence& TestFrameSequence::Metadata(Http2StreamId stream_id, + absl::string_view payload, + bool multiple_frames) { + const std::string encoded_payload = MetadataBlockForPayload(payload); + if (multiple_frames) { + const size_t pos = encoded_payload.size() / 2; + frames_.push_back(absl::make_unique( + stream_id, kMetadataFrameType, 0, encoded_payload.substr(0, pos))); + frames_.push_back(absl::make_unique( + stream_id, kMetadataFrameType, kMetadataEndFlag, + encoded_payload.substr(pos))); + } else { + frames_.push_back(absl::make_unique( + stream_id, kMetadataFrameType, kMetadataEndFlag, + std::move(encoded_payload))); + } + return *this; +} + std::string TestFrameSequence::Serialize() { std::string result; if (!preface_.empty()) { @@ -124,6 +186,16 @@ std::string TestFrameSequence::Serialize() { return result; } +std::string TestFrameSequence::MetadataBlockForPayload( + absl::string_view payload) { + // Encode the payload using a header block. + spdy::SpdyHeaderBlock block; + block["example-payload"] = payload; + spdy::HpackEncoder encoder; + encoder.DisableCompression(); + return encoder.EncodeHeaderBlock(block); +} + } // namespace test } // namespace adapter } // namespace http2 diff --git a/gquiche/http2/adapter/test_frame_sequence.h b/gquiche/http2/adapter/test_frame_sequence.h index 8e83748f..649b529b 100644 --- a/gquiche/http2/adapter/test_frame_sequence.h +++ b/gquiche/http2/adapter/test_frame_sequence.h @@ -1,49 +1,68 @@ #ifndef QUICHE_HTTP2_ADAPTER_TEST_FRAME_SEQUENCE_H_ #define QUICHE_HTTP2_ADAPTER_TEST_FRAME_SEQUENCE_H_ +#include #include #include #include #include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/spdy/core/spdy_protocol.h" namespace http2 { namespace adapter { namespace test { -class TestFrameSequence { +std::vector QUICHE_NO_EXPORT ToHeaders( + absl::Span> headers); + +class QUICHE_NO_EXPORT TestFrameSequence { public: TestFrameSequence() = default; - TestFrameSequence& ClientPreface(); - TestFrameSequence& ServerPreface(); + TestFrameSequence& ClientPreface( + absl::Span settings = {}); + TestFrameSequence& ServerPreface( + absl::Span settings = {}); TestFrameSequence& Data(Http2StreamId stream_id, absl::string_view payload, bool fin = false, absl::optional padding_length = absl::nullopt); TestFrameSequence& RstStream(Http2StreamId stream_id, Http2ErrorCode error); - TestFrameSequence& Settings(absl::Span values); + TestFrameSequence& Settings(absl::Span settings); TestFrameSequence& SettingsAck(); + TestFrameSequence& PushPromise(Http2StreamId stream_id, + Http2StreamId promised_stream_id, + absl::Span headers); TestFrameSequence& Ping(Http2PingId id); TestFrameSequence& PingAck(Http2PingId id); TestFrameSequence& GoAway(Http2StreamId last_good_stream_id, Http2ErrorCode error, absl::string_view payload = ""); + TestFrameSequence& Headers( + Http2StreamId stream_id, + absl::Span> headers, + bool fin = false, bool add_continuation = false); TestFrameSequence& Headers(Http2StreamId stream_id, - spdy::Http2HeaderBlock block, - bool fin = false); + spdy::Http2HeaderBlock block, bool fin = false, + bool add_continuation = false); TestFrameSequence& Headers(Http2StreamId stream_id, - absl::Span headers, - bool fin = false); + absl::Span headers, bool fin = false, + bool add_continuation = false); TestFrameSequence& WindowUpdate(Http2StreamId stream_id, int32_t delta); TestFrameSequence& Priority(Http2StreamId stream_id, Http2StreamId parent_stream_id, int weight, bool exclusive); + TestFrameSequence& Metadata(Http2StreamId stream_id, + absl::string_view payload, + bool multiple_frames = false); std::string Serialize(); + static std::string MetadataBlockForPayload(absl::string_view); + private: std::string preface_; std::vector> frames_; diff --git a/gquiche/http2/adapter/test_utils.cc b/gquiche/http2/adapter/test_utils.cc index 17657a13..4b658975 100644 --- a/gquiche/http2/adapter/test_utils.cc +++ b/gquiche/http2/adapter/test_utils.cc @@ -1,5 +1,11 @@ #include "gquiche/http2/adapter/test_utils.h" +#include + +#include "absl/strings/str_format.h" +#include "gquiche/http2/adapter/http2_visitor_interface.h" +#include "gquiche/common/quiche_endian.h" +#include "gquiche/spdy/core/hpack/hpack_encoder.h" #include "gquiche/spdy/core/spdy_frame_reader.h" namespace http2 { @@ -7,33 +13,129 @@ namespace adapter { namespace test { namespace { +using ConnectionError = Http2VisitorInterface::ConnectionError; + +} // anonymous namespace + +TestDataFrameSource::TestDataFrameSource(Http2VisitorInterface& visitor, + bool has_fin) + : visitor_(visitor), has_fin_(has_fin) {} + +void TestDataFrameSource::AppendPayload(absl::string_view payload) { + QUICHE_CHECK(!end_data_); + if (!payload.empty()) { + payload_fragments_.push_back(std::string(payload)); + current_fragment_ = payload_fragments_.front(); + } +} + +void TestDataFrameSource::EndData() { end_data_ = true; } + +std::pair TestDataFrameSource::SelectPayloadLength( + size_t max_length) { + if (return_error_) { + return {DataFrameSource::kError, false}; + } + // The stream is done if there's no more data, or if |max_length| is at least + // as large as the remaining data. + const bool end_data = end_data_ && (current_fragment_.empty() || + (payload_fragments_.size() == 1 && + max_length >= current_fragment_.size())); + const int64_t length = std::min(max_length, current_fragment_.size()); + return {length, end_data}; +} + +bool TestDataFrameSource::Send(absl::string_view frame_header, + size_t payload_length) { + QUICHE_LOG_IF(DFATAL, payload_length > current_fragment_.size()) + << "payload_length: " << payload_length + << " current_fragment_size: " << current_fragment_.size(); + const std::string concatenated = + absl::StrCat(frame_header, current_fragment_.substr(0, payload_length)); + const int64_t result = visitor_.OnReadyToSend(concatenated); + if (result < 0) { + // Write encountered error. + visitor_.OnConnectionError(ConnectionError::kSendError); + current_fragment_ = {}; + payload_fragments_.clear(); + return false; + } else if (result == 0) { + // Write blocked. + return false; + } else if (static_cast(result) < concatenated.size()) { + // Probably need to handle this better within this test class. + QUICHE_LOG(DFATAL) + << "DATA frame not fully flushed. Connection will be corrupt!"; + visitor_.OnConnectionError(ConnectionError::kSendError); + current_fragment_ = {}; + payload_fragments_.clear(); + return false; + } + if (payload_length > 0) { + current_fragment_.remove_prefix(payload_length); + } + if (current_fragment_.empty() && !payload_fragments_.empty()) { + payload_fragments_.erase(payload_fragments_.begin()); + if (!payload_fragments_.empty()) { + current_fragment_ = payload_fragments_.front(); + } + } + return true; +} + +std::string EncodeHeaders(const spdy::SpdyHeaderBlock& entries) { + spdy::HpackEncoder encoder; + encoder.DisableCompression(); + return encoder.EncodeHeaderBlock(entries); +} + +TestMetadataSource::TestMetadataSource(const spdy::SpdyHeaderBlock& entries) + : encoded_entries_(EncodeHeaders(entries)) { + remaining_ = encoded_entries_; +} + +std::pair TestMetadataSource::Pack(uint8_t* dest, + size_t dest_len) { + const size_t copied = std::min(dest_len, remaining_.size()); + std::memcpy(dest, remaining_.data(), copied); + remaining_.remove_prefix(copied); + return std::make_pair(copied, remaining_.empty()); +} + +namespace { + using TypeAndOptionalLength = std::pair>; -std::vector> LogFriendly( +std::ostream& operator<<( + std::ostream& os, const std::vector& types_and_lengths) { - std::vector> out; - out.reserve(types_and_lengths.size()); - for (const auto type_and_length : types_and_lengths) { - out.push_back({spdy::FrameTypeToString(type_and_length.first), - type_and_length.second - ? absl::StrCat(type_and_length.second.value()) - : ""}); + for (const auto& type_and_length : types_and_lengths) { + os << "(" << spdy::FrameTypeToString(type_and_length.first) << ", " + << (type_and_length.second ? absl::StrCat(type_and_length.second.value()) + : "") + << ") "; } - return out; + return os; } -// Custom gMock matcher, used to determine if a particular type of frame -// is in a string. This is useful in tests where we want to show that a -// particular control frame type is serialized for sending to the peer. +std::string FrameTypeToString(uint8_t frame_type) { + if (spdy::IsDefinedFrameType(frame_type)) { + return spdy::FrameTypeToString(spdy::ParseFrameType(frame_type)); + } else { + return absl::StrFormat("0x%x", static_cast(frame_type)); + } +} + +// Custom gMock matcher, used to implement EqualsFrames(). class SpdyControlFrameMatcher - : public testing::MatcherInterface { + : public testing::MatcherInterface { public: explicit SpdyControlFrameMatcher( std::vector types_and_lengths) : expected_types_and_lengths_(std::move(types_and_lengths)) {} - bool MatchAndExplain(const std::string s, + bool MatchAndExplain(absl::string_view s, testing::MatchResultListener* listener) const override { spdy::SpdyFrameReader reader(s.data(), s.size()); @@ -43,6 +145,11 @@ class SpdyControlFrameMatcher return false; } } + if (!reader.IsDoneReading()) { + size_t bytes_remaining = s.size() - reader.GetBytesConsumed(); + *listener << "; " << bytes_remaining << " bytes left to read!"; + return false; + } return true; } @@ -70,16 +177,8 @@ class SpdyControlFrameMatcher return false; } - if (!spdy::IsDefinedFrameType(raw_type)) { - *listener << "; expected type " << FrameTypeToString(expected_type) - << " but raw type " << static_cast(raw_type) - << " is not a defined frame type!"; - return false; - } - - spdy::SpdyFrameType actual_type = spdy::ParseFrameType(raw_type); - if (actual_type != expected_type) { - *listener << "; actual type: " << FrameTypeToString(actual_type) + if (raw_type != static_cast(expected_type)) { + *listener << "; actual type: " << FrameTypeToString(raw_type) << " but expected type: " << FrameTypeToString(expected_type); return false; } @@ -91,12 +190,12 @@ class SpdyControlFrameMatcher void DescribeTo(std::ostream* os) const override { *os << "Data contains frames of types in sequence " - << LogFriendly(expected_types_and_lengths_); + << expected_types_and_lengths_; } void DescribeNegationTo(std::ostream* os) const override { *os << "Data does not contain frames of types in sequence " - << LogFriendly(expected_types_and_lengths_); + << expected_types_and_lengths_; } private: @@ -105,13 +204,13 @@ class SpdyControlFrameMatcher } // namespace -testing::Matcher ContainsFrames( +testing::Matcher EqualsFrames( std::vector>> types_and_lengths) { return MakeMatcher(new SpdyControlFrameMatcher(std::move(types_and_lengths))); } -testing::Matcher ContainsFrames( +testing::Matcher EqualsFrames( std::vector types) { std::vector>> types_and_lengths; diff --git a/gquiche/http2/adapter/test_utils.h b/gquiche/http2/adapter/test_utils.h index 94be5867..5f67ffef 100644 --- a/gquiche/http2/adapter/test_utils.h +++ b/gquiche/http2/adapter/test_utils.h @@ -1,25 +1,134 @@ #ifndef QUICHE_HTTP2_ADAPTER_TEST_UTILS_H_ #define QUICHE_HTTP2_ADAPTER_TEST_UTILS_H_ +#include #include #include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "gquiche/http2/adapter/data_source.h" +#include "gquiche/http2/adapter/http2_protocol.h" +#include "gquiche/http2/adapter/mock_http2_visitor.h" +#include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/spdy/core/spdy_header_block.h" #include "gquiche/spdy/core/spdy_protocol.h" namespace http2 { namespace adapter { namespace test { -// Matcher that checks whether a string contains HTTP/2 frames of the specified -// ordered sequence of types and lengths. -testing::Matcher ContainsFrames( +class QUICHE_NO_EXPORT DataSavingVisitor + : public testing::StrictMock { + public: + int64_t OnReadyToSend(absl::string_view data) override { + if (has_write_error_) { + return kSendError; + } + if (is_write_blocked_) { + return kSendBlocked; + } + const size_t to_accept = std::min(send_limit_, data.size()); + if (to_accept == 0) { + return kSendBlocked; + } + absl::StrAppend(&data_, data.substr(0, to_accept)); + return to_accept; + } + + bool OnMetadataForStream(Http2StreamId stream_id, + absl::string_view metadata) override { + const bool ret = testing::StrictMock::OnMetadataForStream( + stream_id, metadata); + if (ret) { + auto result = + metadata_map_.try_emplace(stream_id, std::vector()); + result.first->second.push_back(std::string(metadata)); + } + return ret; + } + + const std::vector GetMetadata(Http2StreamId stream_id) { + auto it = metadata_map_.find(stream_id); + if (it == metadata_map_.end()) { + return {}; + } else { + return it->second; + } + } + + const std::string& data() { return data_; } + void Clear() { data_.clear(); } + + void set_send_limit(size_t limit) { send_limit_ = limit; } + + bool is_write_blocked() const { return is_write_blocked_; } + void set_is_write_blocked(bool value) { is_write_blocked_ = value; } + + void set_has_write_error() { has_write_error_ = true; } + + private: + std::string data_; + absl::flat_hash_map> metadata_map_; + size_t send_limit_ = std::numeric_limits::max(); + bool is_write_blocked_ = false; + bool has_write_error_ = false; +}; + +// A test DataFrameSource. Starts out in the empty, blocked state. +class QUICHE_NO_EXPORT TestDataFrameSource : public DataFrameSource { + public: + TestDataFrameSource(Http2VisitorInterface& visitor, bool has_fin); + + void AppendPayload(absl::string_view payload); + void EndData(); + void SimulateError() { return_error_ = true; } + + std::pair SelectPayloadLength(size_t max_length) override; + bool Send(absl::string_view frame_header, size_t payload_length) override; + bool send_fin() const override { return has_fin_; } + + private: + Http2VisitorInterface& visitor_; + std::vector payload_fragments_; + absl::string_view current_fragment_; + // Whether the stream should end with the final frame of data. + const bool has_fin_; + // Whether |payload_fragments_| contains the final segment of data. + bool end_data_ = false; + // Whether SelectPayloadLength() should return an error. + bool return_error_ = false; +}; + +class QUICHE_NO_EXPORT TestMetadataSource : public MetadataSource { + public: + explicit TestMetadataSource(const spdy::SpdyHeaderBlock& entries); + + size_t NumFrames(size_t max_frame_size) const override { + // Round up to the next frame. + return (encoded_entries_.size() + max_frame_size - 1) / max_frame_size; + } + std::pair Pack(uint8_t* dest, size_t dest_len) override; + + private: + const std::string encoded_entries_; + absl::string_view remaining_; +}; + +// These matchers check whether a string consists entirely of HTTP/2 frames of +// the specified ordered sequence. This is useful in tests where we want to show +// that one or more particular frame types are serialized for sending to the +// peer. The match will fail if there are input bytes not consumed by the +// matcher. + +// Requires that frames match both types and lengths. +testing::Matcher EqualsFrames( std::vector>> types_and_lengths); -// Matcher that checks whether a string contains HTTP/2 frames of the specified -// ordered sequence of types. -testing::Matcher ContainsFrames( +// Requires that frames match the specified types. +testing::Matcher EqualsFrames( std::vector types); } // namespace test diff --git a/gquiche/http2/adapter/test_utils_test.cc b/gquiche/http2/adapter/test_utils_test.cc index 4247ebc7..de6a993f 100644 --- a/gquiche/http2/adapter/test_utils_test.cc +++ b/gquiche/http2/adapter/test_utils_test.cc @@ -10,38 +10,37 @@ namespace { using spdy::SpdyFramer; -TEST(ContainsFrames, Empty) { - EXPECT_THAT("", ContainsFrames(std::vector{})); +TEST(EqualsFrames, Empty) { + EXPECT_THAT("", EqualsFrames(std::vector{})); } -TEST(ContainsFrames, SingleFrameWithLength) { +TEST(EqualsFrames, SingleFrameWithLength) { SpdyFramer framer{SpdyFramer::ENABLE_COMPRESSION}; spdy::SpdyPingIR ping{511}; EXPECT_THAT(framer.SerializeFrame(ping), - ContainsFrames({{spdy::SpdyFrameType::PING, 8}})); + EqualsFrames({{spdy::SpdyFrameType::PING, 8}})); spdy::SpdyWindowUpdateIR window_update{1, 101}; EXPECT_THAT(framer.SerializeFrame(window_update), - ContainsFrames({{spdy::SpdyFrameType::WINDOW_UPDATE, 4}})); + EqualsFrames({{spdy::SpdyFrameType::WINDOW_UPDATE, 4}})); spdy::SpdyDataIR data{3, "Some example data, ha ha!"}; EXPECT_THAT(framer.SerializeFrame(data), - ContainsFrames({{spdy::SpdyFrameType::DATA, 25}})); + EqualsFrames({{spdy::SpdyFrameType::DATA, 25}})); } -TEST(ContainsFrames, SingleFrameWithoutLength) { +TEST(EqualsFrames, SingleFrameWithoutLength) { SpdyFramer framer{SpdyFramer::ENABLE_COMPRESSION}; spdy::SpdyRstStreamIR rst_stream{7, spdy::ERROR_CODE_REFUSED_STREAM}; - EXPECT_THAT( - framer.SerializeFrame(rst_stream), - ContainsFrames({{spdy::SpdyFrameType::RST_STREAM, absl::nullopt}})); + EXPECT_THAT(framer.SerializeFrame(rst_stream), + EqualsFrames({{spdy::SpdyFrameType::RST_STREAM, absl::nullopt}})); spdy::SpdyGoAwayIR goaway{13, spdy::ERROR_CODE_ENHANCE_YOUR_CALM, "Consider taking some deep breaths."}; EXPECT_THAT(framer.SerializeFrame(goaway), - ContainsFrames({{spdy::SpdyFrameType::GOAWAY, absl::nullopt}})); + EqualsFrames({{spdy::SpdyFrameType::GOAWAY, absl::nullopt}})); spdy::Http2HeaderBlock block; block[":method"] = "GET"; @@ -49,10 +48,10 @@ TEST(ContainsFrames, SingleFrameWithoutLength) { block[":authority"] = "example.com"; spdy::SpdyHeadersIR headers{17, std::move(block)}; EXPECT_THAT(framer.SerializeFrame(headers), - ContainsFrames({{spdy::SpdyFrameType::HEADERS, absl::nullopt}})); + EqualsFrames({{spdy::SpdyFrameType::HEADERS, absl::nullopt}})); } -TEST(ContainsFrames, MultipleFrames) { +TEST(EqualsFrames, MultipleFrames) { SpdyFramer framer{SpdyFramer::ENABLE_COMPRESSION}; spdy::SpdyPingIR ping{511}; @@ -74,20 +73,48 @@ TEST(ContainsFrames, MultipleFrames) { absl::string_view(framer.SerializeFrame(rst_stream)), absl::string_view(framer.SerializeFrame(goaway)), absl::string_view(framer.SerializeFrame(headers))); + absl::string_view frame_sequence_view = frame_sequence; + EXPECT_THAT(frame_sequence, + EqualsFrames({{spdy::SpdyFrameType::PING, absl::nullopt}, + {spdy::SpdyFrameType::WINDOW_UPDATE, absl::nullopt}, + {spdy::SpdyFrameType::DATA, 25}, + {spdy::SpdyFrameType::RST_STREAM, absl::nullopt}, + {spdy::SpdyFrameType::GOAWAY, 42}, + {spdy::SpdyFrameType::HEADERS, 19}})); + EXPECT_THAT(frame_sequence_view, + EqualsFrames({{spdy::SpdyFrameType::PING, absl::nullopt}, + {spdy::SpdyFrameType::WINDOW_UPDATE, absl::nullopt}, + {spdy::SpdyFrameType::DATA, 25}, + {spdy::SpdyFrameType::RST_STREAM, absl::nullopt}, + {spdy::SpdyFrameType::GOAWAY, 42}, + {spdy::SpdyFrameType::HEADERS, 19}})); EXPECT_THAT( frame_sequence, - ContainsFrames({{spdy::SpdyFrameType::PING, absl::nullopt}, - {spdy::SpdyFrameType::WINDOW_UPDATE, absl::nullopt}, - {spdy::SpdyFrameType::DATA, 25}, - {spdy::SpdyFrameType::RST_STREAM, absl::nullopt}, - {spdy::SpdyFrameType::GOAWAY, 42}, - {spdy::SpdyFrameType::HEADERS, 19}})); + EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY, spdy::SpdyFrameType::HEADERS})); EXPECT_THAT( - frame_sequence, - ContainsFrames( + frame_sequence_view, + EqualsFrames( {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, spdy::SpdyFrameType::GOAWAY, spdy::SpdyFrameType::HEADERS})); + + // If the final frame type is removed the expectation fails, as there are + // bytes left to read. + EXPECT_THAT( + frame_sequence, + testing::Not(EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY}))); + EXPECT_THAT( + frame_sequence_view, + testing::Not(EqualsFrames( + {spdy::SpdyFrameType::PING, spdy::SpdyFrameType::WINDOW_UPDATE, + spdy::SpdyFrameType::DATA, spdy::SpdyFrameType::RST_STREAM, + spdy::SpdyFrameType::GOAWAY}))); } } // namespace diff --git a/gquiche/http2/adapter/window_manager.cc b/gquiche/http2/adapter/window_manager.cc index f3122d70..165f5a1c 100644 --- a/gquiche/http2/adapter/window_manager.cc +++ b/gquiche/http2/adapter/window_manager.cc @@ -73,7 +73,6 @@ void WindowManager::MaybeNotifyListener() { } // For the sake of efficiency, we want to send window updates if less than // half of the max quota is available to the peer at any point in time. - // http://google3/gfe/gfe2/stubby/autobahn_fd_wrapper.cc?l=1180-1183&rcl=307416556 const size_t kDesiredMinWindow = limit_ / 2; const size_t kDesiredMinDelta = limit_ / 3; const size_t delta = limit_ - (buffered_ + window_); diff --git a/gquiche/http2/adapter/window_manager.h b/gquiche/http2/adapter/window_manager.h index 277c24f9..5860cfe1 100644 --- a/gquiche/http2/adapter/window_manager.h +++ b/gquiche/http2/adapter/window_manager.h @@ -1,8 +1,12 @@ #ifndef QUICHE_HTTP2_ADAPTER_WINDOW_MANAGER_H_ #define QUICHE_HTTP2_ADAPTER_WINDOW_MANAGER_H_ +#include + #include +#include "gquiche/common/platform/api/quiche_export.h" + namespace http2 { namespace adapter { @@ -12,7 +16,7 @@ class WindowManagerPeer; // This class keeps track of a HTTP/2 flow control window, notifying a listener // when a window update needs to be sent. This class is not thread-safe. -class WindowManager { +class QUICHE_EXPORT_PRIVATE WindowManager { public: // A WindowUpdateListener is invoked when it is time to send a window update. typedef std::function WindowUpdateListener; diff --git a/gquiche/http2/adapter/window_manager_test.cc b/gquiche/http2/adapter/window_manager_test.cc index ce772837..11078ded 100644 --- a/gquiche/http2/adapter/window_manager_test.cc +++ b/gquiche/http2/adapter/window_manager_test.cc @@ -90,21 +90,21 @@ TEST_F(WindowManagerTest, AvoidWindowUnderflow) { EXPECT_EQ(wm_.CurrentWindowSize(), wm_.WindowSizeLimit()); // Don't buffer more than the total window! wm_.MarkDataBuffered(wm_.WindowSizeLimit() + 1); - EXPECT_EQ(wm_.CurrentWindowSize(), 0); + EXPECT_EQ(wm_.CurrentWindowSize(), 0u); } // Window manager should GFE_BUG and avoid buffered underflow. TEST_F(WindowManagerTest, AvoidBufferedUnderflow) { - EXPECT_EQ(peer_.buffered(), 0); + EXPECT_EQ(peer_.buffered(), 0u); // Don't flush more than has been buffered! EXPECT_QUICHE_BUG(wm_.MarkDataFlushed(1), "buffered underflow"); - EXPECT_EQ(peer_.buffered(), 0); + EXPECT_EQ(peer_.buffered(), 0u); wm_.MarkDataBuffered(42); - EXPECT_EQ(peer_.buffered(), 42); + EXPECT_EQ(peer_.buffered(), 42u); // Don't flush more than has been buffered! EXPECT_QUICHE_BUG(wm_.MarkDataFlushed(43), "buffered underflow"); - EXPECT_EQ(peer_.buffered(), 0); + EXPECT_EQ(peer_.buffered(), 0u); } // This test verifies that WindowManager notifies its listener when window is diff --git a/gquiche/http2/core/http2_trace_logging.cc b/gquiche/http2/core/http2_trace_logging.cc new file mode 100644 index 00000000..e4f8984b --- /dev/null +++ b/gquiche/http2/core/http2_trace_logging.cc @@ -0,0 +1,454 @@ +#include "gquiche/http2/core/http2_trace_logging.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gquiche/common/platform/api/quiche_bug_tracker.h" +#include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/spdy/core/spdy_protocol.h" + +// Convenience macros for printing function arguments in log lines in the +// format arg_name=value. +#define FORMAT_ARG(arg) " " #arg "=" << arg +#define FORMAT_INT_ARG(arg) " " #arg "=" << static_cast(arg) + +// Convenience macros for printing Spdy*IR attributes in log lines in the +// format attrib_name=value. +#define FORMAT_ATTR(ir, attrib) " " #attrib "=" << ir.attrib() +#define FORMAT_INT_ATTR(ir, attrib) \ + " " #attrib "=" << static_cast(ir.attrib()) + +namespace { + +// Logs a container, using a user-provided object to log each individual item. +template +struct ContainerLogger { + explicit ContainerLogger(const T& c, ItemLogger l) + : container(c), item_logger(l) {} + + friend std::ostream& operator<<(std::ostream& out, + const ContainerLogger& logger) { + out << "["; + auto begin = logger.container.begin(); + for (auto it = begin; it != logger.container.end(); ++it) { + if (it != begin) { + out << ", "; + } + logger.item_logger.Log(out, *it); + } + out << "]"; + return out; + } + const T& container; + ItemLogger item_logger; +}; + +// Returns a ContainerLogger that will log |container| using |item_logger|. +template +auto LogContainer(const T& container, ItemLogger item_logger) + -> decltype(ContainerLogger(container, item_logger)) { + return ContainerLogger(container, item_logger); +} + +} // anonymous namespace + +#define FORMAT_HEADER_BLOCK(ir) \ + " header_block=" << LogContainer(ir.header_block(), LogHeaderBlockEntry()) + +namespace http2 { + +using spdy::SettingsMap; +using spdy::SpdyAltSvcIR; +using spdy::SpdyContinuationIR; +using spdy::SpdyDataIR; +using spdy::SpdyGoAwayIR; +using spdy::SpdyHeaderBlock; +using spdy::SpdyHeadersIR; +using spdy::SpdyPingIR; +using spdy::SpdyPriorityIR; +using spdy::SpdyPushPromiseIR; +using spdy::SpdyRstStreamIR; +using spdy::SpdySettingsIR; +using spdy::SpdyStreamId; +using spdy::SpdyUnknownIR; +using spdy::SpdyWindowUpdateIR; + +namespace { + +// Defines how elements of SpdyHeaderBlocks are logged. +struct LogHeaderBlockEntry { + void Log(std::ostream& out, + const SpdyHeaderBlock::value_type& entry) const { // NOLINT + out << "\"" << entry.first << "\": \"" << entry.second << "\""; + } +}; + +// Defines how elements of SettingsMap are logged. +struct LogSettingsEntry { + void Log(std::ostream& out, + const SettingsMap::value_type& entry) const { // NOLINT + out << spdy::SettingsIdToString(entry.first) << ": " << entry.second; + } +}; + +// Defines how elements of AlternativeServiceVector are logged. +struct LogAlternativeService { + void Log(std::ostream& out, + const spdy::SpdyAltSvcWireFormat::AlternativeService& altsvc) + const { // NOLINT + out << "{" + << "protocol_id=" << altsvc.protocol_id << " host=" << altsvc.host + << " port=" << altsvc.port + << " max_age_seconds=" << altsvc.max_age_seconds << " version="; + for (auto v : altsvc.version) { + out << v << ","; + } + out << "}"; + } +}; + +} // anonymous namespace + +Http2TraceLogger::Http2TraceLogger(SpdyFramerVisitorInterface* parent, + absl::string_view perspective, + std::function is_enabled, + const void* connection_id) + : wrapped_(parent), + perspective_(perspective), + is_enabled_(std::move(is_enabled)), + connection_id_(connection_id) {} + +Http2TraceLogger::~Http2TraceLogger() { + if (recording_headers_handler_ != nullptr && + !recording_headers_handler_->decoded_block().empty()) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "connection_id=" << connection_id_ + << " Received headers that were never logged! keys/values:" + << recording_headers_handler_->decoded_block().DebugString(); + } +} + +void Http2TraceLogger::OnError(Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnError:" << FORMAT_ARG(connection_id_) + << ", error=" << Http2DecoderAdapter::SpdyFramerErrorToString(error); + wrapped_->OnError(error, detailed_error); +} + +void Http2TraceLogger::OnCommonHeader(SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnCommonHeader:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(length) << FORMAT_INT_ARG(type) + << FORMAT_INT_ARG(flags); + wrapped_->OnCommonHeader(stream_id, length, type, flags); +} + +void Http2TraceLogger::OnDataFrameHeader(SpdyStreamId stream_id, size_t length, + bool fin) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnDataFrameHeader:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(length) << FORMAT_ARG(fin); + wrapped_->OnDataFrameHeader(stream_id, length, fin); +} + +void Http2TraceLogger::OnStreamFrameData(SpdyStreamId stream_id, + const char* data, size_t len) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamFrameData:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(len); + wrapped_->OnStreamFrameData(stream_id, data, len); +} + +void Http2TraceLogger::OnStreamEnd(SpdyStreamId stream_id) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamEnd:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id); + wrapped_->OnStreamEnd(stream_id); +} + +void Http2TraceLogger::OnStreamPadLength(SpdyStreamId stream_id, size_t value) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamPadLength:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(value); + wrapped_->OnStreamPadLength(stream_id, value); +} + +void Http2TraceLogger::OnStreamPadding(SpdyStreamId stream_id, size_t len) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnStreamPadding:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(len); + wrapped_->OnStreamPadding(stream_id, len); +} + +spdy::SpdyHeadersHandlerInterface* Http2TraceLogger::OnHeaderFrameStart( + SpdyStreamId stream_id) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnHeaderFrameStart:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id); + spdy::SpdyHeadersHandlerInterface* result = + wrapped_->OnHeaderFrameStart(stream_id); + recording_headers_handler_ = + absl::make_unique(result); + result = recording_headers_handler_.get(); + return result; +} + +void Http2TraceLogger::OnHeaderFrameEnd(SpdyStreamId stream_id) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnHeaderFrameEnd:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id); + LogReceivedHeaders(); + wrapped_->OnHeaderFrameEnd(stream_id); + recording_headers_handler_ = nullptr; +} + +void Http2TraceLogger::OnRstStream(SpdyStreamId stream_id, + SpdyErrorCode error_code) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnRstStream:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << " error_code=" << spdy::ErrorCodeToString(error_code); + wrapped_->OnRstStream(stream_id, error_code); +} + +void Http2TraceLogger::OnSettings() { wrapped_->OnSettings(); } + +void Http2TraceLogger::OnSetting(SpdySettingsId id, uint32_t value) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnSetting:" << FORMAT_ARG(connection_id_) + << " id=" << spdy::SettingsIdToString(id) << FORMAT_ARG(value); + wrapped_->OnSetting(id, value); +} + +void Http2TraceLogger::OnSettingsEnd() { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnSettingsEnd:" << FORMAT_ARG(connection_id_); + wrapped_->OnSettingsEnd(); +} + +void Http2TraceLogger::OnSettingsAck() { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnSettingsAck:" << FORMAT_ARG(connection_id_); + wrapped_->OnSettingsAck(); +} + +void Http2TraceLogger::OnPing(SpdyPingId unique_id, bool is_ack) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPing:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(unique_id) + << FORMAT_ARG(is_ack); + wrapped_->OnPing(unique_id, is_ack); +} + +void Http2TraceLogger::OnGoAway(SpdyStreamId last_accepted_stream_id, + SpdyErrorCode error_code) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnGoAway:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(last_accepted_stream_id) + << " error_code=" << spdy::ErrorCodeToString(error_code); + wrapped_->OnGoAway(last_accepted_stream_id, error_code); +} + +bool Http2TraceLogger::OnGoAwayFrameData(const char* goaway_data, size_t len) { + return wrapped_->OnGoAwayFrameData(goaway_data, len); +} + +void Http2TraceLogger::OnHeaders(SpdyStreamId stream_id, bool has_priority, + int weight, SpdyStreamId parent_stream_id, + bool exclusive, bool fin, bool end) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnHeaders:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << FORMAT_ARG(has_priority) << FORMAT_INT_ARG(weight) + << FORMAT_ARG(parent_stream_id) << FORMAT_ARG(exclusive) + << FORMAT_ARG(fin) << FORMAT_ARG(end); + wrapped_->OnHeaders(stream_id, has_priority, weight, parent_stream_id, + exclusive, fin, end); +} + +void Http2TraceLogger::OnWindowUpdate(SpdyStreamId stream_id, + int delta_window_size) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnWindowUpdate:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(delta_window_size); + wrapped_->OnWindowUpdate(stream_id, delta_window_size); +} + +void Http2TraceLogger::OnPushPromise(SpdyStreamId original_stream_id, + SpdyStreamId promised_stream_id, bool end) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPushPromise:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(original_stream_id) << FORMAT_ARG(promised_stream_id) + << FORMAT_ARG(end); + wrapped_->OnPushPromise(original_stream_id, promised_stream_id, end); +} + +void Http2TraceLogger::OnContinuation(SpdyStreamId stream_id, bool end) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnContinuation:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_ARG(end); + wrapped_->OnContinuation(stream_id, end); +} + +void Http2TraceLogger::OnAltSvc( + SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& altsvc_vector) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnAltSvc:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << FORMAT_ARG(origin) << " altsvc_vector=" + << LogContainer(altsvc_vector, LogAlternativeService()); + wrapped_->OnAltSvc(stream_id, origin, altsvc_vector); +} + +void Http2TraceLogger::OnPriority(SpdyStreamId stream_id, + SpdyStreamId parent_stream_id, int weight, + bool exclusive) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPriority:" << FORMAT_ARG(connection_id_) << FORMAT_ARG(stream_id) + << FORMAT_ARG(parent_stream_id) << FORMAT_INT_ARG(weight) + << FORMAT_ARG(exclusive); + wrapped_->OnPriority(stream_id, parent_stream_id, weight, exclusive); +} + +void Http2TraceLogger::OnPriorityUpdate( + SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnPriorityUpdate:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(prioritized_stream_id) << FORMAT_ARG(priority_field_value); + wrapped_->OnPriorityUpdate(prioritized_stream_id, priority_field_value); +} + +bool Http2TraceLogger::OnUnknownFrame(SpdyStreamId stream_id, + uint8_t frame_type) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "OnUnknownFrame:" << FORMAT_ARG(connection_id_) + << FORMAT_ARG(stream_id) << FORMAT_INT_ARG(frame_type); + return wrapped_->OnUnknownFrame(stream_id, frame_type); +} + +void Http2TraceLogger::LogReceivedHeaders() const { + if (recording_headers_handler_ == nullptr) { + QUICHE_BUG(bug_2794_1) << "Cannot log headers before creating handler."; + return; + } + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Received headers;" << FORMAT_ARG(connection_id_) << " keys/values:" + << recording_headers_handler_->decoded_block().DebugString() + << " compressed_bytes=" + << recording_headers_handler_->compressed_header_bytes() + << " uncompressed_bytes=" + << recording_headers_handler_->uncompressed_header_bytes(); +} + +void Http2FrameLogger::VisitRstStream(const SpdyRstStreamIR& rst_stream) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyRstStreamIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(rst_stream, stream_id) + << " error_code=" << spdy::ErrorCodeToString(rst_stream.error_code()); +} + +void Http2FrameLogger::VisitSettings(const SpdySettingsIR& settings) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdySettingsIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(settings, is_ack) + << " values=" << LogContainer(settings.values(), LogSettingsEntry()); +} + +void Http2FrameLogger::VisitPing(const SpdyPingIR& ping) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPingIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(ping, id) << FORMAT_ATTR(ping, is_ack); +} + +void Http2FrameLogger::VisitGoAway(const SpdyGoAwayIR& goaway) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyGoAwayIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(goaway, last_good_stream_id) + << " error_code=" << spdy::ErrorCodeToString(goaway.error_code()) + << FORMAT_ATTR(goaway, description); +} + +void Http2FrameLogger::VisitHeaders(const SpdyHeadersIR& headers) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyHeadersIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(headers, stream_id) << FORMAT_ATTR(headers, fin) + << FORMAT_ATTR(headers, has_priority) << FORMAT_INT_ATTR(headers, weight) + << FORMAT_ATTR(headers, parent_stream_id) + << FORMAT_ATTR(headers, exclusive) << FORMAT_ATTR(headers, padded) + << FORMAT_ATTR(headers, padding_payload_len) + << FORMAT_HEADER_BLOCK(headers); +} + +void Http2FrameLogger::VisitWindowUpdate( + const SpdyWindowUpdateIR& window_update) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyWindowUpdateIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(window_update, stream_id) + << FORMAT_ATTR(window_update, delta); +} + +void Http2FrameLogger::VisitPushPromise(const SpdyPushPromiseIR& push_promise) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPushPromiseIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(push_promise, stream_id) << FORMAT_ATTR(push_promise, fin) + << FORMAT_ATTR(push_promise, promised_stream_id) + << FORMAT_ATTR(push_promise, padded) + << FORMAT_ATTR(push_promise, padding_payload_len) + << FORMAT_HEADER_BLOCK(push_promise); +} + +void Http2FrameLogger::VisitContinuation( + const SpdyContinuationIR& continuation) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyContinuationIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(continuation, stream_id) + << FORMAT_ATTR(continuation, end_headers); +} + +void Http2FrameLogger::VisitAltSvc(const SpdyAltSvcIR& altsvc) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyAltSvcIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(altsvc, stream_id) << FORMAT_ATTR(altsvc, origin) + << " altsvc_vector=" + << LogContainer(altsvc.altsvc_vector(), LogAlternativeService()); +} + +void Http2FrameLogger::VisitPriority(const SpdyPriorityIR& priority) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPriorityIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(priority, stream_id) + << FORMAT_ATTR(priority, parent_stream_id) + << FORMAT_INT_ATTR(priority, weight) << FORMAT_ATTR(priority, exclusive); +} + +void Http2FrameLogger::VisitData(const SpdyDataIR& data) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyDataIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(data, stream_id) << FORMAT_ATTR(data, fin) + << " data_len=" << data.data_len() << FORMAT_ATTR(data, padded) + << FORMAT_ATTR(data, padding_payload_len); +} + +void Http2FrameLogger::VisitPriorityUpdate( + const spdy::SpdyPriorityUpdateIR& priority_update) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyPriorityUpdateIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(priority_update, stream_id) + << FORMAT_ATTR(priority_update, prioritized_stream_id) + << FORMAT_ATTR(priority_update, priority_field_value); +} + +void Http2FrameLogger::VisitAcceptCh( + const spdy::SpdyAcceptChIR& /*accept_ch*/) { + QUICHE_BUG(bug_2794_2) + << "Sending ACCEPT_CH frames is currently unimplemented."; +} + +void Http2FrameLogger::VisitUnknown(const SpdyUnknownIR& ir) { + HTTP2_TRACE_LOG(perspective_, is_enabled_) + << "Wrote SpdyUnknownIR:" << FORMAT_ARG(connection_id_) + << FORMAT_ATTR(ir, stream_id) << FORMAT_INT_ATTR(ir, type) + << FORMAT_INT_ATTR(ir, flags) << FORMAT_ATTR(ir, length); +} + +} // namespace http2 diff --git a/gquiche/http2/core/http2_trace_logging.h b/gquiche/http2/core/http2_trace_logging.h new file mode 100644 index 00000000..d3bb1c3c --- /dev/null +++ b/gquiche/http2/core/http2_trace_logging.h @@ -0,0 +1,140 @@ +// Classes and utilities for supporting HTTP/2 trace logging, which logs +// information about all control and data frames sent and received over +// HTTP/2 connections. + +#ifndef QUICHE_HTTP2_CORE_HTTP2_TRACE_LOGGING_H_ +#define QUICHE_HTTP2_CORE_HTTP2_TRACE_LOGGING_H_ + +#include + +#include "absl/strings/string_view.h" +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/spdy/core/http2_frame_decoder_adapter.h" +#include "gquiche/spdy/core/recording_headers_handler.h" +#include "gquiche/spdy/core/spdy_headers_handler_interface.h" +#include "gquiche/spdy/core/spdy_protocol.h" + +// Logging macro to use for all HTTP/2 trace logging. Iff trace logging is +// enabled, logs at level INFO with a common prefix prepended (to facilitate +// post-hoc filtering of trace logging output). +#define HTTP2_TRACE_LOG(perspective, is_enabled) \ + QUICHE_LOG_IF(INFO, is_enabled()) << "[HTTP2_TRACE " << perspective << "] " + +namespace http2 { + +// Intercepts deframing events to provide detailed logs. Intended to be used for +// manual debugging. +// +// Note any new methods in SpdyFramerVisitorInterface MUST be overridden here to +// properly forward the event. This could be ensured by making every event in +// SpdyFramerVisitorInterface a pure virtual. +class QUICHE_EXPORT_PRIVATE Http2TraceLogger + : public spdy::SpdyFramerVisitorInterface { + public: + typedef spdy::SpdyAltSvcWireFormat SpdyAltSvcWireFormat; + typedef spdy::SpdyErrorCode SpdyErrorCode; + typedef spdy::SpdyFramerVisitorInterface SpdyFramerVisitorInterface; + typedef spdy::SpdyPingId SpdyPingId; + typedef spdy::SpdyPriority SpdyPriority; + typedef spdy::SpdySettingsId SpdySettingsId; + typedef spdy::SpdyStreamId SpdyStreamId; + + Http2TraceLogger(SpdyFramerVisitorInterface* parent, + absl::string_view perspective, + std::function is_enabled, const void* connection_id); + ~Http2TraceLogger() override; + + Http2TraceLogger(const Http2TraceLogger&) = delete; + Http2TraceLogger& operator=(const Http2TraceLogger&) = delete; + + void OnError(http2::Http2DecoderAdapter::SpdyFramerError error, + std::string detailed_error) override; + void OnCommonHeader(SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + spdy::SpdyHeadersHandlerInterface* OnHeaderFrameStart( + SpdyStreamId stream_id) override; + void OnHeaderFrameEnd(SpdyStreamId stream_id) override; + void OnDataFrameHeader(SpdyStreamId stream_id, size_t length, + bool fin) override; + void OnStreamFrameData(SpdyStreamId stream_id, const char* data, + size_t len) override; + void OnStreamEnd(SpdyStreamId stream_id) override; + void OnStreamPadLength(SpdyStreamId stream_id, size_t value) override; + void OnStreamPadding(SpdyStreamId stream_id, size_t len) override; + void OnRstStream(SpdyStreamId stream_id, SpdyErrorCode error_code) override; + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + void OnPing(SpdyPingId unique_id, bool is_ack) override; + void OnSettings() override; + void OnSettingsEnd() override; + void OnSettingsAck() override; + void OnGoAway(SpdyStreamId last_accepted_stream_id, + SpdyErrorCode error_code) override; + bool OnGoAwayFrameData(const char* goaway_data, size_t len) override; + void OnHeaders(SpdyStreamId stream_id, bool has_priority, int weight, + SpdyStreamId parent_stream_id, bool exclusive, bool fin, + bool end) override; + void OnWindowUpdate(SpdyStreamId stream_id, int delta_window_size) override; + void OnPushPromise(SpdyStreamId stream_id, SpdyStreamId promised_stream_id, + bool end) override; + void OnContinuation(SpdyStreamId stream_id, bool end) override; + void OnAltSvc(SpdyStreamId stream_id, absl::string_view origin, + const SpdyAltSvcWireFormat::AlternativeServiceVector& + altsvc_vector) override; + void OnPriority(SpdyStreamId stream_id, SpdyStreamId parent_stream_id, + int weight, bool exclusive) override; + void OnPriorityUpdate(SpdyStreamId prioritized_stream_id, + absl::string_view priority_field_value) override; + bool OnUnknownFrame(SpdyStreamId stream_id, uint8_t frame_type) override; + + private: + void LogReceivedHeaders() const; + + std::unique_ptr recording_headers_handler_; + + SpdyFramerVisitorInterface* wrapped_; + const absl::string_view perspective_; + const std::function is_enabled_; + const void* connection_id_; +}; + +// Visitor to log control frames that have been written. +class QUICHE_EXPORT_PRIVATE Http2FrameLogger : public spdy::SpdyFrameVisitor { + public: + // This class will preface all of its log messages with the value of + // |connection_id| in hexadecimal. + Http2FrameLogger(absl::string_view perspective, + std::function is_enabled, const void* connection_id) + : perspective_(perspective), + is_enabled_(std::move(is_enabled)), + connection_id_(connection_id) {} + + Http2FrameLogger(const Http2FrameLogger&) = delete; + Http2FrameLogger& operator=(const Http2FrameLogger&) = delete; + + void VisitRstStream(const spdy::SpdyRstStreamIR& rst_stream) override; + void VisitSettings(const spdy::SpdySettingsIR& settings) override; + void VisitPing(const spdy::SpdyPingIR& ping) override; + void VisitGoAway(const spdy::SpdyGoAwayIR& goaway) override; + void VisitHeaders(const spdy::SpdyHeadersIR& headers) override; + void VisitWindowUpdate( + const spdy::SpdyWindowUpdateIR& window_update) override; + void VisitPushPromise(const spdy::SpdyPushPromiseIR& push_promise) override; + void VisitContinuation(const spdy::SpdyContinuationIR& continuation) override; + void VisitAltSvc(const spdy::SpdyAltSvcIR& altsvc) override; + void VisitPriority(const spdy::SpdyPriorityIR& priority) override; + void VisitData(const spdy::SpdyDataIR& data) override; + void VisitPriorityUpdate( + const spdy::SpdyPriorityUpdateIR& priority_update) override; + void VisitAcceptCh(const spdy::SpdyAcceptChIR& accept_ch) override; + void VisitUnknown(const spdy::SpdyUnknownIR& ir) override; + + private: + const absl::string_view perspective_; + const std::function is_enabled_; + const void* connection_id_; +}; + +} // namespace http2 + +#endif // QUICHE_HTTP2_CORE_HTTP2_TRACE_LOGGING_H_ diff --git a/gquiche/http2/core/priority_write_scheduler.h b/gquiche/http2/core/priority_write_scheduler.h index 5113bb9f..3f4040ed 100644 --- a/gquiche/http2/core/priority_write_scheduler.h +++ b/gquiche/http2/core/priority_write_scheduler.h @@ -18,6 +18,7 @@ #include "absl/strings/str_cat.h" #include "gquiche/http2/core/write_scheduler.h" #include "gquiche/common/platform/api/quiche_bug_tracker.h" +#include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/common/platform/api/quiche_logging.h" #include "gquiche/spdy/core/spdy_protocol.h" @@ -37,9 +38,9 @@ class PriorityWriteSchedulerPeer; // that priority that are ready to write, as well as a timestamp of the last // I/O event that occurred for a stream of that priority. // -// DO NOT USE. Deprecated. template -class PriorityWriteScheduler : public WriteScheduler { +class QUICHE_EXPORT_PRIVATE PriorityWriteScheduler + : public WriteScheduler { public: using typename WriteScheduler::StreamPrecedenceType; @@ -286,7 +287,7 @@ class PriorityWriteScheduler : public WriteScheduler { // State kept for all registered streams. All ready streams have ready = true // and should be present in priority_infos_[priority].ready_list. - struct StreamInfo { + struct QUICHE_EXPORT_PRIVATE StreamInfo { spdy::SpdyPriority priority; StreamIdType stream_id; bool ready; @@ -296,7 +297,7 @@ class PriorityWriteScheduler : public WriteScheduler { using ReadyList = std::deque; // State kept for each priority level. - struct PriorityInfo { + struct QUICHE_EXPORT_PRIVATE PriorityInfo { // IDs of streams that are ready to write. ReadyList ready_list; // Time of latest write event for stream of this priority, in microseconds. diff --git a/gquiche/http2/core/write_scheduler.h b/gquiche/http2/core/write_scheduler.h index 80ac4686..43729e35 100644 --- a/gquiche/http2/core/write_scheduler.h +++ b/gquiche/http2/core/write_scheduler.h @@ -22,14 +22,6 @@ namespace http2 { // where (writable) higher-priority streams are always given precedence // over lower-priority streams. // -// Http2PriorityWriteScheduler: implements SPDY priority-based stream -// scheduling coupled with the HTTP/2 stream dependency model. This is only -// intended as a transitional step towards Http2WeightedWriteScheduler. -// -// Http2WeightedWriteScheduler (coming soon): implements the HTTP/2 stream -// dependency model with weighted stream scheduling, fully conforming to -// RFC 7540. -// // The type used to represent stream IDs (StreamIdType) is templated in order // to allow for use by both SPDY and QUIC codebases. It must be a POD that // supports comparison (i.e., a numeric type). diff --git a/gquiche/http2/hpack/decoder/hpack_decoder.cc b/gquiche/http2/hpack/decoder/hpack_decoder.cc index 223186ee..806a4204 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder.cc +++ b/gquiche/http2/hpack/decoder/hpack_decoder.cc @@ -110,10 +110,6 @@ bool HpackDecoder::DetectError() { return error_ != HpackDecodingError::kOk; } -size_t HpackDecoder::EstimateMemoryUsage() const { - return Http2EstimateMemoryUsage(entry_buffer_); -} - void HpackDecoder::ReportError(HpackDecodingError error, std::string detailed_error) { HTTP2_DVLOG(3) << "HpackDecoder::ReportError is new=" diff --git a/gquiche/http2/hpack/decoder/hpack_decoder.h b/gquiche/http2/hpack/decoder/hpack_decoder.h index 68c3fcff..af729efe 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder.h +++ b/gquiche/http2/hpack/decoder/hpack_decoder.h @@ -69,6 +69,11 @@ class QUICHE_EXPORT_PRIVATE HpackDecoder { // decoding the SETTINGS ACK, and before the next HPACK block is decoded. void ApplyHeaderTableSizeSetting(uint32_t max_header_table_size); + // Returns the most recently applied value of SETTINGS_HEADER_TABLE_SIZE. + size_t GetCurrentHeaderTableSizeSetting() const { + return decoder_state_.GetCurrentHeaderTableSizeSetting(); + } + // Prepares the decoder for decoding a new HPACK block, and announces this to // its listener. Returns true if OK to continue with decoding, false if an // error has been detected, which for StartDecodingBlock means the error was @@ -92,12 +97,13 @@ class QUICHE_EXPORT_PRIVATE HpackDecoder { // detected. bool DetectError(); + size_t GetDynamicTableSize() const { + return decoder_state_.GetDynamicTableSize(); + } + // Error code if an error has occurred, HpackDecodingError::kOk otherwise. HpackDecodingError error() const { return error_; } - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - std::string detailed_error() const { return detailed_error_; } private: diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_state.h b/gquiche/http2/hpack/decoder/hpack_decoder_state.h index abb5ae9b..c70dfbf2 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_state.h +++ b/gquiche/http2/hpack/decoder/hpack_decoder_state.h @@ -57,6 +57,11 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderState : public HpackWholeEntryListener { // decoding the SETTINGS ACK, and before the next HPACK block is decoded. void ApplyHeaderTableSizeSetting(uint32_t max_header_table_size); + // Returns the most recently applied value of SETTINGS_HEADER_TABLE_SIZE. + size_t GetCurrentHeaderTableSizeSetting() const { + return final_header_table_size_; + } + // OnHeaderBlockStart notifies this object that we're starting to decode the // HPACK payload of a HEADERS or PUSH_PROMISE frame. void OnHeaderBlockStart(); @@ -83,6 +88,10 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderState : public HpackWholeEntryListener { // No further callbacks will be made to the listener. HpackDecodingError error() const { return error_; } + size_t GetDynamicTableSize() const { + return decoder_tables_.current_header_table_size(); + } + const HpackDecoderTables& decoder_tables_for_test() const { return decoder_tables_; } diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_state_test.cc b/gquiche/http2/hpack/decoder/hpack_decoder_state_test.cc index 20eef513..5b01eecc 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_state_test.cc +++ b/gquiche/http2/hpack/decoder/hpack_decoder_state_test.cc @@ -427,8 +427,10 @@ TEST_F(HpackDecoderStateTest, OptionalTableSizeChanges) { // Confirm that required size updates are indeed required before headers. TEST_F(HpackDecoderStateTest, RequiredTableSizeChangeBeforeHeader) { + EXPECT_EQ(4096u, decoder_state_.GetCurrentHeaderTableSizeSetting()); decoder_state_.ApplyHeaderTableSizeSetting(1024); decoder_state_.ApplyHeaderTableSizeSetting(2048); + EXPECT_EQ(2048u, decoder_state_.GetCurrentHeaderTableSizeSetting()); // First provide the required update, and an allowed second update. SendStartAndVerifyCallback(); @@ -442,6 +444,7 @@ TEST_F(HpackDecoderStateTest, RequiredTableSizeChangeBeforeHeader) { // Another HPACK block, but this time missing the required size update. decoder_state_.ApplyHeaderTableSizeSetting(1024); + EXPECT_EQ(1024u, decoder_state_.GetCurrentHeaderTableSizeSetting()); SendStartAndVerifyCallback(); EXPECT_CALL(listener_, OnHeaderErrorDetected(Eq("Missing dynamic table size update"))); diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc b/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc index e9ffccf4..058f53f7 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc +++ b/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.cc @@ -7,7 +7,6 @@ #include #include "gquiche/http2/platform/api/http2_bug_tracker.h" -#include "gquiche/http2/platform/api/http2_estimate_memory_usage.h" #include "gquiche/http2/platform/api/http2_logging.h" namespace http2 { @@ -232,10 +231,6 @@ void HpackDecoderStringBuffer::OutputDebugStringTo(std::ostream& out) const { out << "}"; } -size_t HpackDecoderStringBuffer::EstimateMemoryUsage() const { - return Http2EstimateMemoryUsage(buffer_); -} - std::ostream& operator<<(std::ostream& out, const HpackDecoderStringBuffer& v) { v.OutputDebugStringTo(out); return out; diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.h b/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.h index 3565abf4..3779de96 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.h +++ b/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer.h @@ -67,9 +67,6 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderStringBuffer { Backing backing_for_testing() const { return backing_; } void OutputDebugStringTo(std::ostream& out) const; - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - private: // Storage for the string being buffered, if buffering is necessary // (e.g. if Huffman encoded, buffer_ is storage for the decoded string). diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc b/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc index 779b04b3..11114372 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc +++ b/gquiche/http2/hpack/decoder/hpack_decoder_string_buffer_test.cc @@ -10,7 +10,6 @@ #include "absl/strings/escaping.h" #include "gquiche/http2/platform/api/http2_logging.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/http2/platform/api/http2_test_helpers.h" #include "gquiche/common/platform/api/quiche_test.h" diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_tables.h b/gquiche/http2/hpack/decoder/hpack_decoder_tables.h index b271f56a..6d9caa5b 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_tables.h +++ b/gquiche/http2/hpack/decoder/hpack_decoder_tables.h @@ -25,8 +25,8 @@ #include #include "gquiche/http2/http2_constants.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/quiche_circular_deque.h" namespace http2 { namespace test { @@ -107,7 +107,7 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderDynamicTable { // Removes the oldest dynamic table entry. void RemoveLastEntry(); - quic::QuicCircularDeque table_; + quiche::QuicheCircularDeque table_; // The last received DynamicTableSizeUpdate value, initialized to // SETTINGS_HEADER_TABLE_SIZE. diff --git a/gquiche/http2/hpack/decoder/hpack_decoder_test.cc b/gquiche/http2/hpack/decoder/hpack_decoder_test.cc index fdb5b61e..a97c07ca 100644 --- a/gquiche/http2/hpack/decoder/hpack_decoder_test.cc +++ b/gquiche/http2/hpack/decoder/hpack_decoder_test.cc @@ -961,11 +961,13 @@ TEST_P(HpackDecoderTest, ProcessesOptionalTableSizeUpdates) { // Confirm that the table size can be changed when required, but at most twice. TEST_P(HpackDecoderTest, ProcessesRequiredTableSizeUpdate) { + EXPECT_EQ(4096u, decoder_.GetCurrentHeaderTableSizeSetting()); // One update required, two allowed, one provided, followed by a header. decoder_.ApplyHeaderTableSizeSetting(1024); decoder_.ApplyHeaderTableSizeSetting(2048); EXPECT_EQ(Http2SettingsInfo::DefaultHeaderTableSize(), header_table_size_limit()); + EXPECT_EQ(2048u, decoder_.GetCurrentHeaderTableSizeSetting()); { HpackBlockBuilder hbb; hbb.AppendDynamicTableSizeUpdate(1024); @@ -979,6 +981,7 @@ TEST_P(HpackDecoderTest, ProcessesRequiredTableSizeUpdate) { // One update required, two allowed, two provided, followed by a header. decoder_.ApplyHeaderTableSizeSetting(1000); decoder_.ApplyHeaderTableSizeSetting(1500); + EXPECT_EQ(1500u, decoder_.GetCurrentHeaderTableSizeSetting()); { HpackBlockBuilder hbb; hbb.AppendDynamicTableSizeUpdate(500); @@ -994,6 +997,7 @@ TEST_P(HpackDecoderTest, ProcessesRequiredTableSizeUpdate) { // The third update is rejected, so the final size is 1000, not 500. decoder_.ApplyHeaderTableSizeSetting(500); decoder_.ApplyHeaderTableSizeSetting(1000); + EXPECT_EQ(1000u, decoder_.GetCurrentHeaderTableSizeSetting()); { HpackBlockBuilder hbb; hbb.AppendDynamicTableSizeUpdate(200); @@ -1011,6 +1015,7 @@ TEST_P(HpackDecoderTest, ProcessesRequiredTableSizeUpdate) { EXPECT_EQ(0u, current_header_table_size()); EXPECT_TRUE(header_entries_.empty()); } + EXPECT_EQ(1000u, decoder_.GetCurrentHeaderTableSizeSetting()); // Now that an error has been detected, StartDecodingBlock should return // false. EXPECT_FALSE(decoder_.StartDecodingBlock()); diff --git a/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc b/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc index 35d54c3a..486863c1 100644 --- a/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc +++ b/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.cc @@ -10,7 +10,7 @@ #include "gquiche/http2/platform/api/http2_flags.h" #include "gquiche/http2/platform/api/http2_logging.h" #include "gquiche/http2/platform/api/http2_macros.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace http2 { @@ -35,10 +35,6 @@ void HpackWholeEntryBuffer::BufferStringsIfUnbuffered() { value_.BufferStringIfUnbuffered(); } -size_t HpackWholeEntryBuffer::EstimateMemoryUsage() const { - return Http2EstimateMemoryUsage(name_) + Http2EstimateMemoryUsage(value_); -} - void HpackWholeEntryBuffer::OnIndexedHeader(size_t index) { HTTP2_DVLOG(2) << "HpackWholeEntryBuffer::OnIndexedHeader: index=" << index; listener_->OnIndexedHeader(index); @@ -71,7 +67,8 @@ void HpackWholeEntryBuffer::OnNameStart(bool huffman_encoded, size_t len) { void HpackWholeEntryBuffer::OnNameData(const char* data, size_t len) { HTTP2_DVLOG(2) << "HpackWholeEntryBuffer::OnNameData: len=" << len << " data:\n" - << Http2HexDump(absl::string_view(data, len)); + << quiche::QuicheTextUtils::HexDump( + absl::string_view(data, len)); QUICHE_DCHECK_EQ(maybe_name_index_, 0u); if (!error_detected_ && !name_.OnData(data, len)) { ReportError(HpackDecodingError::kNameHuffmanError, ""); @@ -108,7 +105,8 @@ void HpackWholeEntryBuffer::OnValueStart(bool huffman_encoded, size_t len) { void HpackWholeEntryBuffer::OnValueData(const char* data, size_t len) { HTTP2_DVLOG(2) << "HpackWholeEntryBuffer::OnValueData: len=" << len << " data:\n" - << Http2HexDump(absl::string_view(data, len)); + << quiche::QuicheTextUtils::HexDump( + absl::string_view(data, len)); if (!error_detected_ && !value_.OnData(data, len)) { ReportError(HpackDecodingError::kValueHuffmanError, ""); HTTP2_CODE_COUNT_N(decompress_failure_3, 22, 23); diff --git a/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.h b/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.h index df64f205..03251181 100644 --- a/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.h +++ b/gquiche/http2/hpack/decoder/hpack_whole_entry_buffer.h @@ -63,9 +63,6 @@ class QUICHE_EXPORT_PRIVATE HpackWholeEntryBuffer // no further callbacks will be made to the listener. bool error_detected() const { return error_detected_; } - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - // Implement the HpackEntryDecoderListener methods. void OnIndexedHeader(size_t index) override; diff --git a/gquiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc b/gquiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc index 7945db2b..8d21a25f 100644 --- a/gquiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc +++ b/gquiche/http2/hpack/huffman/hpack_huffman_decoder_test.cc @@ -12,7 +12,6 @@ #include "absl/strings/escaping.h" #include "gquiche/http2/decoder/decode_buffer.h" #include "gquiche/http2/decoder/decode_status.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/http2/platform/api/http2_test_helpers.h" #include "gquiche/http2/tools/random_decoder_test.h" #include "gquiche/common/platform/api/quiche_test.h" diff --git a/gquiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc b/gquiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc index eed9ff41..419acb5f 100644 --- a/gquiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc +++ b/gquiche/http2/hpack/huffman/hpack_huffman_encoder_test.cc @@ -6,7 +6,6 @@ #include "absl/base/macros.h" #include "absl/strings/escaping.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/common/platform/api/quiche_test.h" namespace http2 { diff --git a/gquiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc b/gquiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc index 32197add..45b3a2f7 100644 --- a/gquiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc +++ b/gquiche/http2/hpack/huffman/hpack_huffman_transcoder_test.cc @@ -11,13 +11,15 @@ #include "gquiche/http2/decoder/decode_status.h" #include "gquiche/http2/hpack/huffman/hpack_huffman_decoder.h" #include "gquiche/http2/hpack/huffman/hpack_huffman_encoder.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/http2/tools/random_decoder_test.h" #include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/common/quiche_text_utils.h" using ::testing::AssertionResult; using ::testing::AssertionSuccess; -using ::testing::tuple; +using ::testing::Combine; +using ::testing::Range; +using ::testing::Values; namespace http2 { namespace test { @@ -77,7 +79,7 @@ class HpackHuffmanTranscoderTest : public RandomDecoderTest { std::string encoded; HuffmanEncode(plain, encoded_size, &encoded); VERIFY_EQ(encoded_size, encoded.size()); - if (expected_huffman.size() > 0 || plain.empty()) { + if (!expected_huffman.empty() || plain.empty()) { VERIFY_EQ(encoded, expected_huffman); } input_bytes_expected_ = encoded.size(); @@ -114,8 +116,8 @@ TEST_F(HpackHuffmanTranscoderTest, RoundTripRandomAsciiNonControlString) { const std::string s = RandomAsciiNonControlString(length); ASSERT_TRUE(TranscodeAndValidateSeveralWays(s)) << "Unable to decode:\n\n" - << Http2HexDump(s) << "\n\noutput_buffer_:\n" - << Http2HexDump(output_buffer_); + << quiche::QuicheTextUtils::HexDump(s) << "\n\noutput_buffer_:\n" + << quiche::QuicheTextUtils::HexDump(output_buffer_); } } @@ -124,15 +126,15 @@ TEST_F(HpackHuffmanTranscoderTest, RoundTripRandomBytes) { const std::string s = RandomBytes(length); ASSERT_TRUE(TranscodeAndValidateSeveralWays(s)) << "Unable to decode:\n\n" - << Http2HexDump(s) << "\n\noutput_buffer_:\n" - << Http2HexDump(output_buffer_); + << quiche::QuicheTextUtils::HexDump(s) << "\n\noutput_buffer_:\n" + << quiche::QuicheTextUtils::HexDump(output_buffer_); } } // Two parameters: decoder choice, and the character to round-trip. class HpackHuffmanTranscoderAdjacentCharTest : public HpackHuffmanTranscoderTest, - public ::testing::WithParamInterface { + public testing::WithParamInterface { protected: HpackHuffmanTranscoderAdjacentCharTest() : c_(static_cast(GetParam())) {} @@ -141,8 +143,7 @@ class HpackHuffmanTranscoderAdjacentCharTest }; INSTANTIATE_TEST_SUITE_P(HpackHuffmanTranscoderAdjacentCharTest, - HpackHuffmanTranscoderAdjacentCharTest, - ::testing::Range(0, 256)); + HpackHuffmanTranscoderAdjacentCharTest, Range(0, 256)); // Test c_ adjacent to every other character, both before and after. TEST_P(HpackHuffmanTranscoderAdjacentCharTest, RoundTripAdjacentChar) { @@ -158,11 +159,11 @@ TEST_P(HpackHuffmanTranscoderAdjacentCharTest, RoundTripAdjacentChar) { // Two parameters: character to repeat, number of repeats. class HpackHuffmanTranscoderRepeatedCharTest : public HpackHuffmanTranscoderTest, - public ::testing::WithParamInterface> { + public testing::WithParamInterface> { protected: HpackHuffmanTranscoderRepeatedCharTest() - : c_(static_cast(::testing::get<0>(GetParam()))), - length_(::testing::get<1>(GetParam())) {} + : c_(static_cast(std::get<0>(GetParam()))), + length_(std::get<1>(GetParam())) {} std::string MakeString() { return std::string(length_, c_); } private: @@ -172,9 +173,7 @@ class HpackHuffmanTranscoderRepeatedCharTest INSTANTIATE_TEST_SUITE_P(HpackHuffmanTranscoderRepeatedCharTest, HpackHuffmanTranscoderRepeatedCharTest, - ::testing::Combine(::testing::Range(0, 256), - ::testing::Values(1, 2, 3, 4, 8, 16, - 32))); + Combine(Range(0, 256), Values(1, 2, 3, 4, 8, 16, 32))); TEST_P(HpackHuffmanTranscoderRepeatedCharTest, RoundTripRepeatedChar) { ASSERT_TRUE(TranscodeAndValidateSeveralWays(MakeString())); diff --git a/gquiche/http2/hpack/tools/hpack_block_builder_test.cc b/gquiche/http2/hpack/tools/hpack_block_builder_test.cc index 423b77d4..884b3fdd 100644 --- a/gquiche/http2/hpack/tools/hpack_block_builder_test.cc +++ b/gquiche/http2/hpack/tools/hpack_block_builder_test.cc @@ -5,7 +5,6 @@ #include "gquiche/http2/hpack/tools/hpack_block_builder.h" #include "absl/strings/escaping.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/common/platform/api/quiche_test.h" namespace http2 { diff --git a/gquiche/http2/hpack/tools/hpack_example.cc b/gquiche/http2/hpack/tools/hpack_example.cc index e58b303d..a02e8ebc 100644 --- a/gquiche/http2/hpack/tools/hpack_example.cc +++ b/gquiche/http2/hpack/tools/hpack_example.cc @@ -10,7 +10,6 @@ #include "absl/strings/str_cat.h" #include "gquiche/http2/platform/api/http2_bug_tracker.h" #include "gquiche/http2/platform/api/http2_logging.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" namespace http2 { namespace test { diff --git a/gquiche/http2/hpack/varint/hpack_varint_decoder_test.cc b/gquiche/http2/hpack/varint/hpack_varint_decoder_test.cc index 928c8bfc..6340699d 100644 --- a/gquiche/http2/hpack/varint/hpack_varint_decoder_test.cc +++ b/gquiche/http2/hpack/varint/hpack_varint_decoder_test.cc @@ -12,7 +12,6 @@ #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "gquiche/http2/platform/api/http2_logging.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/http2/tools/random_decoder_test.h" #include "gquiche/common/platform/api/quiche_test.h" diff --git a/gquiche/http2/hpack/varint/hpack_varint_encoder_test.cc b/gquiche/http2/hpack/varint/hpack_varint_encoder_test.cc index e0acd0d3..b2193c0b 100644 --- a/gquiche/http2/hpack/varint/hpack_varint_encoder_test.cc +++ b/gquiche/http2/hpack/varint/hpack_varint_encoder_test.cc @@ -6,7 +6,6 @@ #include "absl/base/macros.h" #include "absl/strings/escaping.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/common/platform/api/quiche_test.h" namespace http2 { diff --git a/gquiche/http2/hpack/varint/hpack_varint_round_trip_test.cc b/gquiche/http2/hpack/varint/hpack_varint_round_trip_test.cc index e99d16dc..0242124f 100644 --- a/gquiche/http2/hpack/varint/hpack_varint_round_trip_test.cc +++ b/gquiche/http2/hpack/varint/hpack_varint_round_trip_test.cc @@ -18,9 +18,9 @@ #include "absl/strings/string_view.h" #include "gquiche/http2/hpack/tools/hpack_block_builder.h" #include "gquiche/http2/platform/api/http2_logging.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "gquiche/http2/tools/random_decoder_test.h" #include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/common/quiche_text_utils.h" using ::testing::AssertionFailure; using ::testing::AssertionSuccess; @@ -166,7 +166,7 @@ class HpackVarintRoundTripTest : public RandomDecoderTest { std::string msg = absl::StrCat("value=", value, " (0x", absl::Hex(value), "), prefix_length=", prefix_length, ", expected_bytes=", expected_bytes, "\n", - Http2HexDump(buffer_)); + quiche::QuicheTextUtils::HexDump(buffer_)); if (value == minimum) { HTTP2_LOG(INFO) << "Checking minimum; " << msg; @@ -221,7 +221,8 @@ class HpackVarintRoundTripTest : public RandomDecoderTest { if (expected_bytes < 11) { // Confirm the claim that beyond requires more bytes. Encode(beyond, prefix_length); - EXPECT_EQ(expected_bytes + 1, buffer_.size()) << Http2HexDump(buffer_); + EXPECT_EQ(expected_bytes + 1, buffer_.size()) + << quiche::QuicheTextUtils::HexDump(buffer_); } std::set values; @@ -285,9 +286,9 @@ TEST_F(HpackVarintRoundTripTest, Encode) { for (uint64_t value : values) { EncodeNoRandom(value, prefix_length); - std::string dump = Http2HexDump(buffer_); + std::string dump = quiche::QuicheTextUtils::HexDump(buffer_); HTTP2_LOG(INFO) << absl::StrFormat("%10llu %0#18x ", value, value) - << Http2HexDump(buffer_).substr(7); + << quiche::QuicheTextUtils::HexDump(buffer_).substr(7); } } } diff --git a/gquiche/http2/http2_constants.cc b/gquiche/http2/http2_constants.cc index 32b5bbde..2acb13f6 100644 --- a/gquiche/http2/http2_constants.cc +++ b/gquiche/http2/http2_constants.cc @@ -7,8 +7,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "gquiche/http2/platform/api/http2_flag_utils.h" #include "gquiche/http2/platform/api/http2_logging.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" namespace http2 { @@ -151,4 +151,30 @@ std::string Http2SettingsParameterToString(Http2SettingsParameter v) { return Http2SettingsParameterToString(static_cast(v)); } +// Invalid HTTP/2 header names according to +// https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.2.2. +// TODO(b/78024822): Consider adding "upgrade" to this set. +constexpr char const* kHttp2InvalidHeaderNames[] = { + "connection", "host", "keep-alive", "proxy-connection", + "transfer-encoding", "", +}; + +constexpr char const* kHttp2InvalidHeaderNamesOld[] = { + "connection", "host", "keep-alive", "proxy-connection", "transfer-encoding", +}; + +const InvalidHeaderSet& GetInvalidHttp2HeaderSet() { + if (!GetQuicheReloadableFlag(quic, quic_verify_request_headers_2)) { + static const auto* invalid_header_set_old = + new InvalidHeaderSet(std::begin(http2::kHttp2InvalidHeaderNamesOld), + std::end(http2::kHttp2InvalidHeaderNamesOld)); + return *invalid_header_set_old; + } + HTTP2_RELOADABLE_FLAG_COUNT_N(quic_verify_request_headers_2, 3, 3); + static const auto* invalid_header_set = + new InvalidHeaderSet(std::begin(http2::kHttp2InvalidHeaderNames), + std::end(http2::kHttp2InvalidHeaderNames)); + return *invalid_header_set; +} + } // namespace http2 diff --git a/gquiche/http2/http2_constants.h b/gquiche/http2/http2_constants.h index 42fbe761..fe186c26 100644 --- a/gquiche/http2/http2_constants.h +++ b/gquiche/http2/http2_constants.h @@ -12,20 +12,19 @@ #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "gquiche/http2/platform/api/http2_flags.h" #include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/common/quiche_text_utils.h" namespace http2 { // TODO(jamessynge): create http2_simple_types for types similar to // SpdyStreamId, but not for structures like Http2FrameHeader. Then will be // able to move these stream id functions there. -constexpr uint32_t UInt31Mask() { - return 0x7fffffff; -} -constexpr uint32_t StreamIdMask() { - return UInt31Mask(); -} +constexpr uint32_t UInt31Mask() { return 0x7fffffff; } +constexpr uint32_t StreamIdMask() { return UInt31Mask(); } // The value used to identify types of frames. Upper case to match the RFC. // The comments indicate which flags are valid for that frame type. @@ -234,7 +233,7 @@ inline std::ostream& operator<<(std::ostream& out, Http2SettingsParameter v) { // Information about the initial, minimum and maximum value of settings (not // applicable to all settings parameters). -class Http2SettingsInfo { +class QUICHE_EXPORT_PRIVATE Http2SettingsInfo { public: // Default value for HEADER_TABLE_SIZE. static constexpr uint32_t DefaultHeaderTableSize() { return 4096; } @@ -259,6 +258,15 @@ class Http2SettingsInfo { static constexpr uint32_t MaximumMaxFrameSize() { return (1 << 24) - 1; } }; +// Http3 early fails upper case request headers, but Http2 still needs case +// insensitive comparison. +using InvalidHeaderSet = + absl::flat_hash_set; + +// Returns all disallowed HTTP/2 headers. +QUICHE_EXPORT_PRIVATE const InvalidHeaderSet& GetInvalidHttp2HeaderSet(); + } // namespace http2 #endif // QUICHE_HTTP2_HTTP2_CONSTANTS_H_ diff --git a/gquiche/http2/http2_structures.cc b/gquiche/http2/http2_structures.cc index 744ee5cf..b435c447 100644 --- a/gquiche/http2/http2_structures.cc +++ b/gquiche/http2/http2_structures.cc @@ -9,7 +9,6 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" namespace http2 { diff --git a/gquiche/http2/http2_structures_test.cc b/gquiche/http2/http2_structures_test.cc index 5ba19544..0bf0c603 100644 --- a/gquiche/http2/http2_structures_test.cc +++ b/gquiche/http2/http2_structures_test.cc @@ -9,7 +9,7 @@ // Note that EXPECT.*DEATH tests are slow (a fork is probably involved). // And in case you're wondering, yes, these are ridiculously thorough tests, -// but believe it or not, I've found stupid bugs this way. +// but believe it or not, I've found silly bugs this way. #include #include diff --git a/gquiche/http2/platform/api/http2_bug_tracker.h b/gquiche/http2/platform/api/http2_bug_tracker.h index b5b7bff1..945401c2 100644 --- a/gquiche/http2/platform/api/http2_bug_tracker.h +++ b/gquiche/http2/platform/api/http2_bug_tracker.h @@ -5,15 +5,15 @@ #ifndef QUICHE_HTTP2_PLATFORM_API_HTTP2_BUG_TRACKER_H_ #define QUICHE_HTTP2_PLATFORM_API_HTTP2_BUG_TRACKER_H_ -#include "platform/http2_platform_impl/http2_bug_tracker_impl.h" +#include "gquiche/common/platform/api/quiche_bug_tracker.h" -#define HTTP2_BUG(x) HTTP2_BUG_IMPL(x) -#define HTTP2_BUG_IF HTTP2_BUG_IF_IMPL +#define HTTP2_BUG QUICHE_BUG +#define HTTP2_BUG_IF QUICHE_BUG_IF // V2 macros are the same as all the HTTP2_BUG flavor above, but they take a // bug_id parameter. -#define HTTP2_BUG_V2 HTTP2_BUG_V2_IMPL -#define HTTP2_BUG_IF_V2 HTTP2_BUG_IF_V2_IMPL +#define HTTP2_BUG_V2 QUICHE_BUG +#define HTTP2_BUG_IF_V2 QUICHE_BUG_IF #define FLAGS_http2_always_log_bugs_for_tests \ FLAGS_http2_always_log_bugs_for_tests_IMPL diff --git a/gquiche/http2/platform/api/http2_flag_utils.h b/gquiche/http2/platform/api/http2_flag_utils.h index 404dc4ef..900cf542 100644 --- a/gquiche/http2/platform/api/http2_flag_utils.h +++ b/gquiche/http2/platform/api/http2_flag_utils.h @@ -8,7 +8,7 @@ #include "platform/http2_platform_impl/http2_flag_utils_impl.h" #define HTTP2_RELOADABLE_FLAG_COUNT HTTP2_RELOADABLE_FLAG_COUNT_IMPL -#define HTTP2_RELOADABLE_FLAG_COUNT_N HTTP2_RELOADABLE_FLAG_COUNT_N_IMPL +#define HTTP2_RELOADABLE_FLAG_COUNT_N(x,y,z) HTTP2_RELOADABLE_FLAG_COUNT_N_IMPL(x,y,z) #define HTTP2_RESTART_FLAG_COUNT HTTP2_RESTART_FLAG_COUNT_IMPL #define HTTP2_RESTART_FLAG_COUNT_N HTTP2_RESTART_FLAG_COUNT_N_IMPL diff --git a/gquiche/http2/platform/api/http2_flags.h b/gquiche/http2/platform/api/http2_flags.h index cfe387ea..274e9436 100644 --- a/gquiche/http2/platform/api/http2_flags.h +++ b/gquiche/http2/platform/api/http2_flags.h @@ -5,7 +5,7 @@ #ifndef QUICHE_HTTP2_PLATFORM_API_HTTP2_FLAGS_H_ #define QUICHE_HTTP2_PLATFORM_API_HTTP2_FLAGS_H_ -#include "platform/http2_platform_impl/http2_flags_impl.h" +#include "gquiche/common/platform/api/quiche_flags.h" #define GetHttp2ReloadableFlag(flag) GetQuicheReloadableFlag(http2, flag) #define SetHttp2ReloadableFlag(flag, value) \ diff --git a/gquiche/http2/platform/api/http2_logging.h b/gquiche/http2/platform/api/http2_logging.h index 7a322c40..35d8d3bf 100644 --- a/gquiche/http2/platform/api/http2_logging.h +++ b/gquiche/http2/platform/api/http2_logging.h @@ -1,7 +1,7 @@ #ifndef QUICHE_HTTP2_PLATFORM_API_HTTP2_LOGGING_H_ #define QUICHE_HTTP2_PLATFORM_API_HTTP2_LOGGING_H_ -#include "platform/http2_platform_impl/http2_logging_impl.h" +#include "gquiche/common/platform/api/quiche_logging.h" #define HTTP2_LOG(severity) QUICHE_LOG(severity) diff --git a/gquiche/http2/test_tools/frame_parts.cc b/gquiche/http2/test_tools/frame_parts.cc index 1d9f2aef..d4305223 100644 --- a/gquiche/http2/test_tools/frame_parts.cc +++ b/gquiche/http2/test_tools/frame_parts.cc @@ -126,7 +126,7 @@ void FrameParts::SetAltSvcExpected(absl::string_view origin, opt_altsvc_value_length_ = value.size(); } -bool FrameParts::OnFrameHeader(const Http2FrameHeader& header) { +bool FrameParts::OnFrameHeader(const Http2FrameHeader& /*header*/) { ADD_FAILURE() << "OnFrameHeader: " << *this; return true; } diff --git a/gquiche/http2/test_tools/http2_random.cc b/gquiche/http2/test_tools/http2_random.cc index 3d59acfe..dacb38c8 100644 --- a/gquiche/http2/test_tools/http2_random.cc +++ b/gquiche/http2/test_tools/http2_random.cc @@ -2,7 +2,6 @@ #include "absl/strings/escaping.h" #include "gquiche/http2/platform/api/http2_logging.h" -#include "gquiche/http2/platform/api/http2_string_utils.h" #include "openssl/chacha.h" #include "openssl/rand.h" diff --git a/gquiche/quic/core/batch_writer/quic_batch_writer_base.cc b/gquiche/quic/core/batch_writer/quic_batch_writer_base.cc index 20bbb5c7..d84bc349 100644 --- a/gquiche/quic/core/batch_writer/quic_batch_writer_base.cc +++ b/gquiche/quic/core/batch_writer/quic_batch_writer_base.cc @@ -24,20 +24,10 @@ WriteResult QuicBatchWriterBase::WritePacket( const WriteResult result = InternalWritePacket(buffer, buf_len, self_address, peer_address, options); - if (GetQuicReloadableFlag(quic_batch_writer_fix_write_blocked)) { - if (IsWriteBlockedStatus(result.status)) { - if (result.status == WRITE_STATUS_BLOCKED_DATA_BUFFERED) { - QUIC_CODE_COUNT(quic_batch_writer_fix_write_blocked_data_buffered); - } else { - QUIC_CODE_COUNT(quic_batch_writer_fix_write_blocked_data_not_buffered); - } - write_blocked_ = true; - } - } else { - if (result.status == WRITE_STATUS_BLOCKED) { - write_blocked_ = true; - } + if (IsWriteBlockedStatus(result.status)) { + write_blocked_ = true; } + return result; } diff --git a/gquiche/quic/core/batch_writer/quic_batch_writer_base.h b/gquiche/quic/core/batch_writer/quic_batch_writer_base.h index aa8d8d4b..665a0e85 100644 --- a/gquiche/quic/core/batch_writer/quic_batch_writer_base.h +++ b/gquiche/quic/core/batch_writer/quic_batch_writer_base.h @@ -60,7 +60,7 @@ class QUIC_EXPORT_PRIVATE QuicBatchWriterBase : public QuicPacketWriter { const QuicBatchWriterBuffer& batch_buffer() const { return *batch_buffer_; } QuicBatchWriterBuffer& batch_buffer() { return *batch_buffer_; } - const QuicCircularDeque& buffered_writes() const { + const quiche::QuicheCircularDeque& buffered_writes() const { return batch_buffer_->buffered_writes(); } diff --git a/gquiche/quic/core/batch_writer/quic_batch_writer_buffer.h b/gquiche/quic/core/batch_writer/quic_batch_writer_buffer.h index 8d0a5281..5480d8d8 100644 --- a/gquiche/quic/core/batch_writer/quic_batch_writer_buffer.h +++ b/gquiche/quic/core/batch_writer/quic_batch_writer_buffer.h @@ -6,11 +6,11 @@ #define QUICHE_QUIC_PLATFORM_IMPL_BATCH_WRITER_QUIC_BATCH_WRITER_BUFFER_H_ #include "absl/base/optimization.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_linux_socket_utils.h" #include "gquiche/quic/core/quic_packet_writer.h" #include "gquiche/quic/platform/api/quic_ip_address.h" #include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -61,7 +61,7 @@ class QUIC_EXPORT_PRIVATE QuicBatchWriterBuffer { }; PopResult PopBufferedWrite(int32_t num_buffered_writes); - const QuicCircularDeque& buffered_writes() const { + const quiche::QuicheCircularDeque& buffered_writes() const { return buffered_writes_; } @@ -87,7 +87,7 @@ class QUIC_EXPORT_PRIVATE QuicBatchWriterBuffer { bool Invariants() const; const char* buffer_end() const { return buffer_ + sizeof(buffer_); } ABSL_CACHELINE_ALIGNED char buffer_[kBufferSize]; - QuicCircularDeque buffered_writes_; + quiche::QuicheCircularDeque buffered_writes_; }; } // namespace quic diff --git a/gquiche/quic/core/batch_writer/quic_batch_writer_test.h b/gquiche/quic/core/batch_writer/quic_batch_writer_test.h index 76617ace..21b6ee79 100644 --- a/gquiche/quic/core/batch_writer/quic_batch_writer_test.h +++ b/gquiche/quic/core/batch_writer/quic_batch_writer_test.h @@ -12,6 +12,7 @@ #include #include +#include "absl/base/optimization.h" #include "gquiche/quic/core/batch_writer/quic_batch_writer_base.h" #include "gquiche/quic/core/quic_udp_socket.h" #include "gquiche/quic/platform/api/quic_test.h" @@ -259,8 +260,9 @@ class QUIC_EXPORT_PRIVATE QuicUdpBatchWriterIOTest QuicSocketAddress self_address_; QuicSocketAddress peer_address_; - char packet_buffer_[1500]; - char control_buffer_[kDefaultUdpPacketControlBufferSize]; + ABSL_CACHELINE_ALIGNED char packet_buffer_[1500]; + ABSL_CACHELINE_ALIGNED char + control_buffer_[kDefaultUdpPacketControlBufferSize]; int address_family_; const size_t data_size_; const size_t packet_size_; diff --git a/gquiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc b/gquiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc index 1987b537..761825f7 100644 --- a/gquiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc +++ b/gquiche/quic/core/batch_writer/quic_gso_batch_writer_test.cc @@ -322,11 +322,9 @@ TEST_F(QuicGsoBatchWriterTest, WriteBlockDataBuffered) { })); ASSERT_EQ(WriteResult(WRITE_STATUS_BLOCKED_DATA_BUFFERED, EWOULDBLOCK), WritePacket(&writer, 50)); - if (GetQuicReloadableFlag(quic_batch_writer_fix_write_blocked)) { - EXPECT_TRUE(writer.IsWriteBlocked()); - } else { - EXPECT_FALSE(writer.IsWriteBlocked()); - } + + EXPECT_TRUE(writer.IsWriteBlocked()); + ASSERT_EQ(250u, writer.batch_buffer().SizeInUse()); ASSERT_EQ(3u, writer.buffered_writes().size()); } diff --git a/gquiche/quic/core/chlo_extractor.cc b/gquiche/quic/core/chlo_extractor.cc index eb4e2914..8dc6898f 100644 --- a/gquiche/quic/core/chlo_extractor.cc +++ b/gquiche/quic/core/chlo_extractor.cc @@ -17,7 +17,6 @@ #include "gquiche/quic/core/quic_framer.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/congestion_control/bandwidth_sampler.cc b/gquiche/quic/core/congestion_control/bandwidth_sampler.cc index 9c5f3fdb..fdab53b1 100644 --- a/gquiche/quic/core/congestion_control/bandwidth_sampler.cc +++ b/gquiche/quic/core/congestion_control/bandwidth_sampler.cc @@ -23,21 +23,70 @@ std::ostream& operator<<(std::ostream& os, const SendTimeState& s) { return os; } -QuicByteCount MaxAckHeightTracker::Update(QuicBandwidth bandwidth_estimate, - QuicRoundTripCount round_trip_count, - QuicTime ack_time, - QuicByteCount bytes_acked) { - if (aggregation_epoch_start_time_ == QuicTime::Zero()) { +QuicByteCount MaxAckHeightTracker::Update( + QuicBandwidth bandwidth_estimate, bool is_new_max_bandwidth, + QuicRoundTripCount round_trip_count, + QuicPacketNumber last_sent_packet_number, + QuicPacketNumber last_acked_packet_number, QuicTime ack_time, + QuicByteCount bytes_acked) { + bool force_new_epoch = false; + + if (reduce_extra_acked_on_bandwidth_increase_ && is_new_max_bandwidth) { + // Save and clear existing entries. + ExtraAckedEvent best = max_ack_height_filter_.GetBest(); + ExtraAckedEvent second_best = max_ack_height_filter_.GetSecondBest(); + ExtraAckedEvent third_best = max_ack_height_filter_.GetThirdBest(); + max_ack_height_filter_.Clear(); + + // Reinsert the heights into the filter after recalculating. + QuicByteCount expected_bytes_acked = bandwidth_estimate * best.time_delta; + if (expected_bytes_acked < best.bytes_acked) { + best.extra_acked = best.bytes_acked - expected_bytes_acked; + max_ack_height_filter_.Update(best, best.round); + } + expected_bytes_acked = bandwidth_estimate * second_best.time_delta; + if (expected_bytes_acked < second_best.bytes_acked) { + QUICHE_DCHECK_LE(best.round, second_best.round); + second_best.extra_acked = second_best.bytes_acked - expected_bytes_acked; + max_ack_height_filter_.Update(second_best, second_best.round); + } + expected_bytes_acked = bandwidth_estimate * third_best.time_delta; + if (expected_bytes_acked < third_best.bytes_acked) { + QUICHE_DCHECK_LE(second_best.round, third_best.round); + third_best.extra_acked = third_best.bytes_acked - expected_bytes_acked; + max_ack_height_filter_.Update(third_best, third_best.round); + } + } + + // If any packet sent after the start of the epoch has been acked, start a new + // epoch. + if (start_new_aggregation_epoch_after_full_round_ && + last_sent_packet_number_before_epoch_.IsInitialized() && + last_acked_packet_number.IsInitialized() && + last_acked_packet_number > last_sent_packet_number_before_epoch_) { + QUIC_DVLOG(3) << "Force starting a new aggregation epoch. " + "last_sent_packet_number_before_epoch_:" + << last_sent_packet_number_before_epoch_ + << ", last_acked_packet_number:" << last_acked_packet_number; + if (reduce_extra_acked_on_bandwidth_increase_) { + QUIC_BUG(quic_bwsampler_46) + << "A full round of aggregation should never " + << "pass with startup_include_extra_acked(B204) enabled."; + } + force_new_epoch = true; + } + if (aggregation_epoch_start_time_ == QuicTime::Zero() || force_new_epoch) { aggregation_epoch_bytes_ = bytes_acked; aggregation_epoch_start_time_ = ack_time; + last_sent_packet_number_before_epoch_ = last_sent_packet_number; ++num_ack_aggregation_epochs_; return 0; } // Compute how many bytes are expected to be delivered, assuming max bandwidth // is correct. - QuicByteCount expected_bytes_acked = - bandwidth_estimate * (ack_time - aggregation_epoch_start_time_); + QuicTime::Delta aggregation_delta = ack_time - aggregation_epoch_start_time_; + QuicByteCount expected_bytes_acked = bandwidth_estimate * aggregation_delta; // Reset the current aggregation epoch as soon as the ack arrival rate is less // than or equal to the max bandwidth. if (aggregation_epoch_bytes_ <= @@ -50,13 +99,13 @@ QuicByteCount MaxAckHeightTracker::Update(QuicBandwidth bandwidth_estimate, << ack_aggregation_bandwidth_threshold_ << ", expected_bytes_acked:" << expected_bytes_acked << ", bandwidth_estimate:" << bandwidth_estimate - << ", aggregation_duration:" - << (ack_time - aggregation_epoch_start_time_) + << ", aggregation_duration:" << aggregation_delta << ", new_aggregation_epoch:" << ack_time << ", new_aggregation_bytes_acked:" << bytes_acked; // Reset to start measuring a new aggregation epoch. aggregation_epoch_bytes_ = bytes_acked; aggregation_epoch_start_time_ = ack_time; + last_sent_packet_number_before_epoch_ = last_sent_packet_number; ++num_ack_aggregation_epochs_; return 0; } @@ -67,13 +116,17 @@ QuicByteCount MaxAckHeightTracker::Update(QuicBandwidth bandwidth_estimate, QuicByteCount extra_bytes_acked = aggregation_epoch_bytes_ - expected_bytes_acked; QUIC_DVLOG(3) << "Updating MaxAckHeight. ack_time:" << ack_time - << ", round trip count:" << round_trip_count + << ", last sent packet:" << last_sent_packet_number << ", bandwidth_estimate:" << bandwidth_estimate << ", bytes_acked:" << bytes_acked << ", expected_bytes_acked:" << expected_bytes_acked << ", aggregation_epoch_bytes_:" << aggregation_epoch_bytes_ << ", extra_bytes_acked:" << extra_bytes_acked; - max_ack_height_filter_.Update(extra_bytes_acked, round_trip_count); + ExtraAckedEvent new_event; + new_event.extra_acked = extra_bytes_acked; + new_event.bytes_acked = aggregation_epoch_bytes_; + new_event.time_delta = aggregation_delta; + max_ack_height_filter_.Update(new_event, round_trip_count); return extra_bytes_acked; } @@ -93,7 +146,8 @@ BandwidthSampler::BandwidthSampler( unacked_packet_map_(unacked_packet_map), max_ack_height_tracker_(max_height_tracker_window_length), total_bytes_acked_after_last_ack_event_(0), - overestimate_avoidance_(false) {} + overestimate_avoidance_(false), + limit_max_ack_height_tracker_by_send_rate_(false) {} BandwidthSampler::BandwidthSampler(const BandwidthSampler& other) : total_bytes_sent_(other.total_bytes_sent_), @@ -105,6 +159,7 @@ BandwidthSampler::BandwidthSampler(const BandwidthSampler& other) last_acked_packet_sent_time_(other.last_acked_packet_sent_time_), last_acked_packet_ack_time_(other.last_acked_packet_ack_time_), last_sent_packet_(other.last_sent_packet_), + last_acked_packet_(other.last_acked_packet_), is_app_limited_(other.is_app_limited_), end_of_app_limited_phase_(other.end_of_app_limited_phase_), connection_state_map_(other.connection_state_map_), @@ -115,7 +170,9 @@ BandwidthSampler::BandwidthSampler(const BandwidthSampler& other) max_ack_height_tracker_(other.max_ack_height_tracker_), total_bytes_acked_after_last_ack_event_( other.total_bytes_acked_after_last_ack_event_), - overestimate_avoidance_(other.overestimate_avoidance_) {} + overestimate_avoidance_(other.overestimate_avoidance_), + limit_max_ack_height_tracker_by_send_rate_( + other.limit_max_ack_height_tracker_by_send_rate_) {} void BandwidthSampler::EnableOverestimateAvoidance() { if (overestimate_avoidance_) { @@ -244,6 +301,7 @@ BandwidthSampler::OnCongestionEvent(QuicTime ack_time, } SendTimeState last_acked_packet_send_state; + QuicBandwidth max_send_rate = QuicBandwidth::Zero(); for (const auto& packet : acked_packets) { BandwidthSample sample = OnPacketAcknowledged(ack_time, packet.packet_number); @@ -260,6 +318,9 @@ BandwidthSampler::OnCongestionEvent(QuicTime ack_time, event_sample.sample_max_bandwidth = sample.bandwidth; event_sample.sample_is_app_limited = sample.state_at_send.is_app_limited; } + if (!sample.send_rate.IsInfinite()) { + max_send_rate = std::max(max_send_rate, sample.send_rate); + } const QuicByteCount inflight_sample = total_bytes_acked() - last_acked_packet_send_state.total_bytes_acked; if (inflight_sample > event_sample.sample_max_inflight) { @@ -282,15 +343,21 @@ BandwidthSampler::OnCongestionEvent(QuicTime ack_time, : last_acked_packet_send_state; } + bool is_new_max_bandwidth = event_sample.sample_max_bandwidth > max_bandwidth; max_bandwidth = std::max(max_bandwidth, event_sample.sample_max_bandwidth); - event_sample.extra_acked = OnAckEventEnd( - std::min(est_bandwidth_upper_bound, max_bandwidth), round_trip_count); + if (limit_max_ack_height_tracker_by_send_rate_) { + max_bandwidth = std::max(max_bandwidth, max_send_rate); + } + // TODO(ianswett): Why is the min being passed in here? + event_sample.extra_acked = + OnAckEventEnd(std::min(est_bandwidth_upper_bound, max_bandwidth), + is_new_max_bandwidth, round_trip_count); return event_sample; } QuicByteCount BandwidthSampler::OnAckEventEnd( - QuicBandwidth bandwidth_estimate, + QuicBandwidth bandwidth_estimate, bool is_new_max_bandwidth, QuicRoundTripCount round_trip_count) { const QuicByteCount newly_acked_bytes = total_bytes_acked_ - total_bytes_acked_after_last_ack_event_; @@ -299,9 +366,9 @@ QuicByteCount BandwidthSampler::OnAckEventEnd( return 0; } total_bytes_acked_after_last_ack_event_ = total_bytes_acked_; - QuicByteCount extra_acked = max_ack_height_tracker_.Update( - bandwidth_estimate, round_trip_count, last_acked_packet_ack_time_, + bandwidth_estimate, is_new_max_bandwidth, round_trip_count, + last_sent_packet_, last_acked_packet_, last_acked_packet_ack_time_, newly_acked_bytes); // If |extra_acked| is zero, i.e. this ack event marks the start of a new ack // aggregation epoch, save LessRecentPoint, which is the last ack point of the @@ -316,6 +383,7 @@ QuicByteCount BandwidthSampler::OnAckEventEnd( BandwidthSample BandwidthSampler::OnPacketAcknowledged( QuicTime ack_time, QuicPacketNumber packet_number) { + last_acked_packet_ = packet_number; ConnectionStateOnSentPacket* sent_packet_pointer = connection_state_map_.GetEntry(packet_number); if (sent_packet_pointer == nullptr) { @@ -412,6 +480,7 @@ BandwidthSample BandwidthSampler::OnPacketAcknowledgedInner( // means that the RTT measurements here can be artificially high, especially // on low bandwidth connections. sample.rtt = ack_time - sent_packet.sent_time; + sample.send_rate = send_rate; SentPacketToSendTimeState(sent_packet, &sample.state_at_send); if (sample.bandwidth.IsZero()) { diff --git a/gquiche/quic/core/congestion_control/bandwidth_sampler.h b/gquiche/quic/core/congestion_control/bandwidth_sampler.h index b83d48d6..ac6d9a51 100644 --- a/gquiche/quic/core/congestion_control/bandwidth_sampler.h +++ b/gquiche/quic/core/congestion_control/bandwidth_sampler.h @@ -9,13 +9,14 @@ #include "gquiche/quic/core/congestion_control/windowed_filter.h" #include "gquiche/quic/core/packet_number_indexed_queue.h" #include "gquiche/quic/core/quic_bandwidth.h" -#include "gquiche/quic/core/quic_circular_deque.h" +#include "gquiche/quic/core/quic_packet_number.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_unacked_packet_map.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flags.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -47,6 +48,7 @@ struct QUIC_EXPORT_PRIVATE SendTimeState { bytes_in_flight(bytes_in_flight) {} SendTimeState(const SendTimeState& other) = default; + SendTimeState& operator=(const SendTimeState& other) = default; friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, const SendTimeState& s); @@ -76,20 +78,39 @@ struct QUIC_EXPORT_PRIVATE SendTimeState { QuicByteCount bytes_in_flight; }; +struct QUIC_NO_EXPORT ExtraAckedEvent { + // The excess bytes acknowlwedged in the time delta for this event. + QuicByteCount extra_acked = 0; + + // The bytes acknowledged and time delta from the event. + QuicByteCount bytes_acked = 0; + QuicTime::Delta time_delta = QuicTime::Delta::Zero(); + // The round trip of the event. + QuicRoundTripCount round = 0; + + inline bool operator>=(const ExtraAckedEvent& other) const { + return extra_acked >= other.extra_acked; + } + inline bool operator==(const ExtraAckedEvent& other) const { + return extra_acked == other.extra_acked; + } +}; + struct QUIC_EXPORT_PRIVATE BandwidthSample { // The bandwidth at that particular sample. Zero if no valid bandwidth sample // is available. - QuicBandwidth bandwidth; + QuicBandwidth bandwidth = QuicBandwidth::Zero(); // The RTT measurement at this particular sample. Zero if no RTT sample is // available. Does not correct for delayed ack time. - QuicTime::Delta rtt; + QuicTime::Delta rtt = QuicTime::Delta::Zero(); + + // |send_rate| is computed from the current packet being acked('P') and an + // earlier packet that is acked before P was sent. + QuicBandwidth send_rate = QuicBandwidth::Infinite(); // States captured when the packet was sent. SendTimeState state_at_send; - - BandwidthSample() - : bandwidth(QuicBandwidth::Zero()), rtt(QuicTime::Delta::Zero()) {} }; // MaxAckHeightTracker is part of the BandwidthSampler. It is called after every @@ -97,27 +118,42 @@ struct QUIC_EXPORT_PRIVATE BandwidthSample { class QUIC_EXPORT_PRIVATE MaxAckHeightTracker { public: explicit MaxAckHeightTracker(QuicRoundTripCount initial_filter_window) - : max_ack_height_filter_(initial_filter_window, 0, 0) {} + : max_ack_height_filter_(initial_filter_window, ExtraAckedEvent(), 0) {} - QuicByteCount Get() const { return max_ack_height_filter_.GetBest(); } + QuicByteCount Get() const { + return max_ack_height_filter_.GetBest().extra_acked; + } QuicByteCount Update(QuicBandwidth bandwidth_estimate, + bool is_new_max_bandwidth, QuicRoundTripCount round_trip_count, - QuicTime ack_time, - QuicByteCount bytes_acked); + QuicPacketNumber last_sent_packet_number, + QuicPacketNumber last_acked_packet_number, + QuicTime ack_time, QuicByteCount bytes_acked); void SetFilterWindowLength(QuicRoundTripCount length) { max_ack_height_filter_.SetWindowLength(length); } void Reset(QuicByteCount new_height, QuicRoundTripCount new_time) { - max_ack_height_filter_.Reset(new_height, new_time); + ExtraAckedEvent new_event; + new_event.extra_acked = new_height; + new_event.round = new_time; + max_ack_height_filter_.Reset(new_event, new_time); } void SetAckAggregationBandwidthThreshold(double threshold) { ack_aggregation_bandwidth_threshold_ = threshold; } + void SetStartNewAggregationEpochAfterFullRound(bool value) { + start_new_aggregation_epoch_after_full_round_ = value; + } + + void SetReduceExtraAckedOnBandwidthIncrease(bool value) { + reduce_extra_acked_on_bandwidth_increase_ = value; + } + double ack_aggregation_bandwidth_threshold() const { return ack_aggregation_bandwidth_threshold_; } @@ -129,20 +165,23 @@ class QUIC_EXPORT_PRIVATE MaxAckHeightTracker { private: // Tracks the maximum number of bytes acked faster than the estimated // bandwidth. - using MaxAckHeightFilter = WindowedFilter, - QuicRoundTripCount, - QuicRoundTripCount>; + using MaxAckHeightFilter = + WindowedFilter, + QuicRoundTripCount, QuicRoundTripCount>; MaxAckHeightFilter max_ack_height_filter_; // The time this aggregation started and the number of bytes acked during it. QuicTime aggregation_epoch_start_time_ = QuicTime::Zero(); QuicByteCount aggregation_epoch_bytes_ = 0; + // The last sent packet number before the current aggregation epoch started. + QuicPacketNumber last_sent_packet_number_before_epoch_; // The number of ack aggregation epochs ever started, including the ongoing // one. Stats only. uint64_t num_ack_aggregation_epochs_ = 0; double ack_aggregation_bandwidth_threshold_ = GetQuicFlag(FLAGS_quic_ack_aggregation_bandwidth_threshold); + bool start_new_aggregation_epoch_after_full_round_ = false; + bool reduce_extra_acked_on_bandwidth_increase_ = false; }; // An interface common to any class that can provide bandwidth samples from the @@ -324,6 +363,7 @@ class QUIC_EXPORT_PRIVATE BandwidthSampler : public BandwidthSamplerInterface { QuicBandwidth est_bandwidth_upper_bound, QuicRoundTripCount round_trip_count) override; QuicByteCount OnAckEventEnd(QuicBandwidth bandwidth_estimate, + bool is_new_max_bandwidth, QuicRoundTripCount round_trip_count); void OnAppLimited() override; @@ -354,6 +394,18 @@ class QUIC_EXPORT_PRIVATE BandwidthSampler : public BandwidthSamplerInterface { max_ack_height_tracker_.Reset(new_height, new_time); } + void SetStartNewAggregationEpochAfterFullRound(bool value) { + max_ack_height_tracker_.SetStartNewAggregationEpochAfterFullRound(value); + } + + void SetLimitMaxAckHeightTrackerBySendRate(bool value) { + limit_max_ack_height_tracker_by_send_rate_ = value; + } + + void SetReduceExtraAckedOnBandwidthIncrease(bool value) { + max_ack_height_tracker_.SetReduceExtraAckedOnBandwidthIncrease(value); + } + // AckPoint represents a point on the ack line. struct QUIC_NO_EXPORT AckPoint { QuicTime ack_time = QuicTime::Zero(); @@ -527,6 +579,9 @@ class QUIC_EXPORT_PRIVATE BandwidthSampler : public BandwidthSamplerInterface { // The most recently sent packet. QuicPacketNumber last_sent_packet_; + // The most recently acked packet. + QuicPacketNumber last_acked_packet_; + // Indicates whether the bandwidth sampler is currently in an app-limited // phase. bool is_app_limited_; @@ -540,7 +595,7 @@ class QUIC_EXPORT_PRIVATE BandwidthSampler : public BandwidthSamplerInterface { PacketNumberIndexedQueue connection_state_map_; RecentAckPoints recent_ack_points_; - QuicCircularDeque a0_candidates_; + quiche::QuicheCircularDeque a0_candidates_; // Maximum number of tracked packets. const QuicPacketCount max_tracked_packets_; @@ -562,6 +617,9 @@ class QUIC_EXPORT_PRIVATE BandwidthSampler : public BandwidthSamplerInterface { // True if connection option 'BSAO' is set. bool overestimate_avoidance_; + + // True if connection option 'BBRB' is set. + bool limit_max_ack_height_tracker_by_send_rate_; }; } // namespace quic diff --git a/gquiche/quic/core/congestion_control/bandwidth_sampler_test.cc b/gquiche/quic/core/congestion_control/bandwidth_sampler_test.cc index e67631aa..cb10fc41 100644 --- a/gquiche/quic/core/congestion_control/bandwidth_sampler_test.cc +++ b/gquiche/quic/core/congestion_control/bandwidth_sampler_test.cc @@ -685,22 +685,23 @@ TEST_P(BandwidthSamplerTest, AckHeightRespectBandwidthEstimateUpperBound) { QuicBandwidth::FromBytesAndTimeDelta(kRegularPacketSize, time_between_packets); - // Send and ack packet 1. + // Send packets 1 to 4 and ack packet 1. SendPacket(1); clock_.AdvanceTime(time_between_packets); + SendPacket(2); + SendPacket(3); + SendPacket(4); BandwidthSampler::CongestionEventSample sample = OnCongestionEvent({1}, {}); EXPECT_EQ(first_packet_sending_rate, sample.sample_max_bandwidth); EXPECT_EQ(first_packet_sending_rate, max_bandwidth_); - // Send and ack packet 2, 3 and 4. + // Ack packet 2, 3 and 4, all of which uses S(1) to calculate ack rate since + // there were no acks at the time they were sent. round_trip_count_++; est_bandwidth_upper_bound_ = first_packet_sending_rate * 0.3; - SendPacket(2); - SendPacket(3); - SendPacket(4); clock_.AdvanceTime(time_between_packets); sample = OnCongestionEvent({2, 3, 4}, {}); - EXPECT_EQ(first_packet_sending_rate * 3, sample.sample_max_bandwidth); + EXPECT_EQ(first_packet_sending_rate * 2, sample.sample_max_bandwidth); EXPECT_EQ(max_bandwidth_, sample.sample_max_bandwidth); EXPECT_LT(2 * kRegularPacketSize, sample.extra_acked); @@ -710,6 +711,7 @@ class MaxAckHeightTrackerTest : public QuicTest { protected: MaxAckHeightTrackerTest() : tracker_(/*initial_filter_window=*/10) { tracker_.SetAckAggregationBandwidthThreshold(1.8); + tracker_.SetStartNewAggregationEpochAfterFullRound(true); } // Run a full aggregation episode, which is one or more aggregated acks, @@ -749,8 +751,9 @@ class MaxAckHeightTrackerTest : public QuicTest { QuicByteCount last_extra_acked = 0; for (QuicByteCount bytes = 0; bytes < aggregation_bytes; bytes += bytes_per_ack) { - QuicByteCount extra_acked = - tracker_.Update(bandwidth_, RoundTripCount(), now_, bytes_per_ack); + QuicByteCount extra_acked = tracker_.Update( + bandwidth_, true, RoundTripCount(), last_sent_packet_number_, + last_acked_packet_number_, now_, bytes_per_ack); QUIC_VLOG(1) << "T" << now_ << ": Update after " << bytes_per_ack << " bytes acked, " << extra_acked << " extra bytes acked"; // |extra_acked| should be 0 if either @@ -784,6 +787,8 @@ class MaxAckHeightTrackerTest : public QuicTest { QuicBandwidth bandwidth_ = QuicBandwidth::FromBytesPerSecond(10 * 1000); QuicTime now_ = QuicTime::Zero() + QuicTime::Delta::FromMilliseconds(1); QuicTime::Delta rtt_ = QuicTime::Delta::FromMilliseconds(60); + QuicPacketNumber last_sent_packet_number_; + QuicPacketNumber last_acked_packet_number_; }; TEST_F(MaxAckHeightTrackerTest, VeryAggregatedLargeAck) { @@ -864,5 +869,21 @@ TEST_F(MaxAckHeightTrackerTest, NotAggregated) { EXPECT_LT(2u, tracker_.num_ack_aggregation_epochs()); } +TEST_F(MaxAckHeightTrackerTest, StartNewEpochAfterAFullRound) { + last_sent_packet_number_ = QuicPacketNumber(10); + AggregationEpisode(bandwidth_ * 2, QuicTime::Delta::FromMilliseconds(50), 100, + true); + + last_acked_packet_number_ = QuicPacketNumber(11); + // Update with a tiny bandwidth causes a very low expected bytes acked, which + // in turn causes the current epoch to continue if the |tracker_| doesn't + // check the packet numbers. + tracker_.Update(bandwidth_ * 0.1, true, RoundTripCount(), + last_sent_packet_number_, last_acked_packet_number_, now_, + 100); + + EXPECT_EQ(2u, tracker_.num_ack_aggregation_epochs()); +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/core/congestion_control/bbr2_misc.cc b/gquiche/quic/core/congestion_control/bbr2_misc.cc index c099b371..e61284a7 100644 --- a/gquiche/quic/core/congestion_control/bbr2_misc.cc +++ b/gquiche/quic/core/congestion_control/bbr2_misc.cc @@ -102,10 +102,13 @@ void Bbr2NetworkModel::OnCongestionEventStart( lost_packets, MaxBandwidth(), bandwidth_lo(), RoundTripCount()); + if (sample.extra_acked == 0) { + cwnd_limited_before_aggregation_epoch_ = + congestion_event->prior_bytes_in_flight >= congestion_event->prior_cwnd; + } + if (sample.last_packet_send_state.is_valid) { congestion_event->last_packet_send_state = sample.last_packet_send_state; - congestion_event->last_sample_is_app_limited = - sample.last_packet_send_state.is_app_limited; } // Avoid updating |max_bandwidth_filter_| if a) this is a loss-only event, or @@ -116,9 +119,9 @@ void Bbr2NetworkModel::OnCongestionEventStart( << total_bytes_acked() - prior_bytes_acked << " bytes from " << acked_packets.size() << " packets have been acked, but sample_max_bandwidth is zero."; + congestion_event->sample_max_bandwidth = sample.sample_max_bandwidth; if (!sample.sample_is_app_limited || sample.sample_max_bandwidth > MaxBandwidth()) { - congestion_event->sample_max_bandwidth = sample.sample_max_bandwidth; max_bandwidth_filter_.Update(congestion_event->sample_max_bandwidth); } } @@ -159,6 +162,12 @@ void Bbr2NetworkModel::OnCongestionEventStart( congestion_event->last_packet_send_state.total_bytes_acked; max_bytes_delivered_in_round_ = std::max(max_bytes_delivered_in_round_, bytes_delivered); + // TODO(ianswett) Consider treating any bytes lost as decreasing inflight, + // because it's a sign of overutilization, not underutilization. + if (min_bytes_in_flight_in_round_ == 0 || + congestion_event->bytes_in_flight < min_bytes_in_flight_in_round_) { + min_bytes_in_flight_in_round_ = congestion_event->bytes_in_flight; + } } // |bandwidth_latest_| and |inflight_latest_| only increased within a round. @@ -188,99 +197,108 @@ void Bbr2NetworkModel::OnCongestionEventStart( void Bbr2NetworkModel::AdaptLowerBounds( const Bbr2CongestionEvent& congestion_event) { - if (Params().bw_lo_mode_ != Bbr2Params::DEFAULT) { - if (congestion_event.bytes_lost == 0) { + if (Params().bw_lo_mode_ == Bbr2Params::DEFAULT) { + if (!congestion_event.end_of_round_trip || + congestion_event.is_probing_for_bandwidth) { return; } - // Ignore losses from packets sent when probing for more bandwidth in - // STARTUP or PROBE_UP when they're lost in DRAIN or PROBE_DOWN. - if (pacing_gain_ < 1) { - return; - } - // Decrease bandwidth_lo whenever there is loss. - // Set bandwidth_lo_ if it is not yet set. - if (bandwidth_lo_.IsInfinite()) { - bandwidth_lo_ = MaxBandwidth(); - } - // Save bandwidth_lo_ if it hasn't already been saved. - if (prior_bandwidth_lo_.IsZero()) { - prior_bandwidth_lo_ = bandwidth_lo_; - } - switch (Params().bw_lo_mode_) { - case Bbr2Params::MIN_RTT_REDUCTION: - bandwidth_lo_ = - bandwidth_lo_ - QuicBandwidth::FromBytesAndTimeDelta( - congestion_event.bytes_lost, MinRtt()); - break; - case Bbr2Params::INFLIGHT_REDUCTION: { - // Use a max of BDP and inflight to avoid starving app-limited flows. - const QuicByteCount effective_inflight = - std::max(BDP(), congestion_event.prior_bytes_in_flight); - // This could use bytes_lost_in_round if the bandwidth_lo_ was saved - // when entering 'recovery', but this BBRv2 implementation doesn't have - // recovery defined. - bandwidth_lo_ = bandwidth_lo_ * - ((effective_inflight - congestion_event.bytes_lost) / - static_cast(effective_inflight)); - break; + + if (bytes_lost_in_round_ > 0) { + if (bandwidth_lo_.IsInfinite()) { + bandwidth_lo_ = MaxBandwidth(); } - case Bbr2Params::CWND_REDUCTION: - bandwidth_lo_ = - bandwidth_lo_ * - ((congestion_event.prior_cwnd - congestion_event.bytes_lost) / - static_cast(congestion_event.prior_cwnd)); - break; - case Bbr2Params::DEFAULT: - QUIC_BUG(quic_bug_10466_1) << "Unreachable case DEFAULT."; - } - if (pacing_gain_ > Params().startup_full_bw_threshold) { - // In STARTUP, pacing_gain_ is applied to bandwidth_lo_ in - // UpdatePacingRate, so this backs that multiplication out to allow the - // pacing rate to decrease, but not below - // bandwidth_latest_ * startup_full_bw_threshold. - bandwidth_lo_ = - std::max(bandwidth_lo_, - bandwidth_latest_ * - (Params().startup_full_bw_threshold / pacing_gain_)); - } else { - // Ensure bandwidth_lo isn't lower than bandwidth_latest_. - bandwidth_lo_ = std::max(bandwidth_lo_, bandwidth_latest_); - } - // If it's the end of the round, ensure bandwidth_lo doesn't decrease more - // than beta. - if (GetQuicReloadableFlag(quic_bbr2_fix_bw_lo_mode) && - congestion_event.end_of_round_trip) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_fix_bw_lo_mode, 2, 2); bandwidth_lo_ = - std::max(bandwidth_lo_, prior_bandwidth_lo_ * (1.0 - Params().beta)); - prior_bandwidth_lo_ = QuicBandwidth::Zero(); + std::max(bandwidth_latest_, bandwidth_lo_ * (1.0 - Params().beta)); + QUIC_DVLOG(3) << "bandwidth_lo_ updated to " << bandwidth_lo_ + << ", bandwidth_latest_ is " << bandwidth_latest_; + + if (Params().ignore_inflight_lo) { + return; + } + if (inflight_lo_ == inflight_lo_default()) { + inflight_lo_ = congestion_event.prior_cwnd; + } + inflight_lo_ = std::max( + inflight_latest_, inflight_lo_ * (1.0 - Params().beta)); } - // This early return ignores inflight_lo as well. return; } - if (!congestion_event.end_of_round_trip || - congestion_event.is_probing_for_bandwidth) { + + // Params().bw_lo_mode_ != Bbr2Params::DEFAULT + if (congestion_event.bytes_lost == 0) { return; } - - if (bytes_lost_in_round_ > 0) { - if (bandwidth_lo_.IsInfinite()) { - bandwidth_lo_ = MaxBandwidth(); + // Ignore losses from packets sent when probing for more bandwidth in + // STARTUP or PROBE_UP when they're lost in DRAIN or PROBE_DOWN. + if (pacing_gain_ < 1) { + return; + } + // Decrease bandwidth_lo whenever there is loss. + // Set bandwidth_lo_ if it is not yet set. + if (bandwidth_lo_.IsInfinite()) { + bandwidth_lo_ = MaxBandwidth(); + } + // Save bandwidth_lo_ if it hasn't already been saved. + if (prior_bandwidth_lo_.IsZero()) { + prior_bandwidth_lo_ = bandwidth_lo_; + } + switch (Params().bw_lo_mode_) { + case Bbr2Params::MIN_RTT_REDUCTION: + bandwidth_lo_ = + bandwidth_lo_ - QuicBandwidth::FromBytesAndTimeDelta( + congestion_event.bytes_lost, MinRtt()); + break; + case Bbr2Params::INFLIGHT_REDUCTION: { + // Use a max of BDP and inflight to avoid starving app-limited flows. + const QuicByteCount effective_inflight = + std::max(BDP(), congestion_event.prior_bytes_in_flight); + // This could use bytes_lost_in_round if the bandwidth_lo_ was saved + // when entering 'recovery', but this BBRv2 implementation doesn't have + // recovery defined. + bandwidth_lo_ = + bandwidth_lo_ * ((effective_inflight - congestion_event.bytes_lost) / + static_cast(effective_inflight)); + break; } + case Bbr2Params::CWND_REDUCTION: + bandwidth_lo_ = + bandwidth_lo_ * + ((congestion_event.prior_cwnd - congestion_event.bytes_lost) / + static_cast(congestion_event.prior_cwnd)); + break; + case Bbr2Params::DEFAULT: + QUIC_BUG(quic_bug_10466_1) << "Unreachable case DEFAULT."; + } + QuicBandwidth last_bandwidth = bandwidth_latest_; + // sample_max_bandwidth will be Zero() if the loss is triggered by a timer + // expiring. Ideally we'd use the most recent bandwidth sample, + // but bandwidth_latest is safer than Zero(). + if (!congestion_event.sample_max_bandwidth.IsZero()) { + // bandwidth_latest_ is the max bandwidth for the round, but to allow + // fast, conservation style response to loss, use the last sample. + last_bandwidth = congestion_event.sample_max_bandwidth; + } + if (pacing_gain_ > Params().startup_full_bw_threshold) { + // In STARTUP, pacing_gain_ is applied to bandwidth_lo_ in + // UpdatePacingRate, so this backs that multiplication out to allow the + // pacing rate to decrease, but not below + // last_bandwidth * startup_full_bw_threshold. + // TODO(ianswett): Consider altering pacing_gain_ when in STARTUP instead. + bandwidth_lo_ = std::max( + bandwidth_lo_, + last_bandwidth * (Params().startup_full_bw_threshold / pacing_gain_)); + } else { + // Ensure bandwidth_lo isn't lower than last_bandwidth. + bandwidth_lo_ = std::max(bandwidth_lo_, last_bandwidth); + } + // If it's the end of the round, ensure bandwidth_lo doesn't decrease more + // than beta. + if (congestion_event.end_of_round_trip) { bandwidth_lo_ = - std::max(bandwidth_latest_, bandwidth_lo_ * (1.0 - Params().beta)); - QUIC_DVLOG(3) << "bandwidth_lo_ updated to " << bandwidth_lo_ - << ", bandwidth_latest_ is " << bandwidth_latest_; - - if (Params().ignore_inflight_lo) { - return; - } - if (inflight_lo_ == inflight_lo_default()) { - inflight_lo_ = congestion_event.prior_cwnd; - } - inflight_lo_ = std::max( - inflight_latest_, inflight_lo_ * (1.0 - Params().beta)); + std::max(bandwidth_lo_, prior_bandwidth_lo_ * (1.0 - Params().beta)); + prior_bandwidth_lo_ = QuicBandwidth::Zero(); } + // These modes ignore inflight_lo as well. } void Bbr2NetworkModel::OnCongestionEventFinish( @@ -316,14 +334,6 @@ bool Bbr2NetworkModel::MaybeExpireMinRtt( return true; } -bool Bbr2NetworkModel::IsCongestionWindowLimited( - const Bbr2CongestionEvent& congestion_event) const { - QuicByteCount prior_bytes_in_flight = congestion_event.bytes_in_flight + - congestion_event.bytes_acked + - congestion_event.bytes_lost; - return prior_bytes_in_flight >= congestion_event.prior_cwnd; -} - bool Bbr2NetworkModel::IsInflightTooHigh( const Bbr2CongestionEvent& congestion_event, int64_t max_loss_events) const { @@ -370,6 +380,7 @@ void Bbr2NetworkModel::OnNewRound() { bytes_lost_in_round_ = 0; loss_events_in_round_ = 0; max_bytes_delivered_in_round_ = 0; + min_bytes_in_flight_in_round_ = 0; } void Bbr2NetworkModel::cap_inflight_lo(QuicByteCount cap) { @@ -387,13 +398,10 @@ QuicByteCount Bbr2NetworkModel::inflight_hi_with_headroom() const { return inflight_hi_ > headroom ? inflight_hi_ - headroom : 0; } -Bbr2NetworkModel::BandwidthGrowth Bbr2NetworkModel::CheckBandwidthGrowth( +bool Bbr2NetworkModel::HasBandwidthGrowth( const Bbr2CongestionEvent& congestion_event) { QUICHE_DCHECK(!full_bandwidth_reached_); QUICHE_DCHECK(congestion_event.end_of_round_trip); - if (congestion_event.last_sample_is_app_limited) { - return APP_LIMITED; - } QuicBandwidth threshold = full_bandwidth_baseline_ * Params().startup_full_bw_threshold; @@ -404,14 +412,15 @@ Bbr2NetworkModel::BandwidthGrowth Bbr2NetworkModel::CheckBandwidthGrowth( << " (Still growing) @ " << congestion_event.event_time; full_bandwidth_baseline_ = MaxBandwidth(); rounds_without_bandwidth_growth_ = 0; - return GROWTH; + return true; } - ++rounds_without_bandwidth_growth_; - BandwidthGrowth return_value = NO_GROWTH; - if (rounds_without_bandwidth_growth_ >= Params().startup_full_bw_rounds) { + + // full_bandwidth_reached is only set to true when not app-limited, except + // when exit_startup_on_persistent_queue is true. + if (rounds_without_bandwidth_growth_ >= Params().startup_full_bw_rounds && + !congestion_event.last_packet_send_state.is_app_limited) { full_bandwidth_reached_ = true; - return_value = EXIT; } QUIC_DVLOG(3) << " CheckBandwidthGrowth at end of round. max_bandwidth:" << MaxBandwidth() << ", threshold:" << threshold @@ -419,7 +428,27 @@ Bbr2NetworkModel::BandwidthGrowth Bbr2NetworkModel::CheckBandwidthGrowth( << " full_bw_reached:" << full_bandwidth_reached_ << " @ " << congestion_event.event_time; - return return_value; + return false; +} + +bool Bbr2NetworkModel::CheckPersistentQueue( + const Bbr2CongestionEvent& congestion_event, float bdp_gain) { + QUICHE_DCHECK(congestion_event.end_of_round_trip); + QuicByteCount target = bdp_gain * BDP(); + if (bdp_gain >= 2) { + // Use a more conservative threshold for STARTUP because CWND gain is 2. + if (target <= QueueingThresholdExtraBytes()) { + return false; + } + target -= QueueingThresholdExtraBytes(); + } else { + target += QueueingThresholdExtraBytes(); + } + if (min_bytes_in_flight_in_round_ > target) { + full_bandwidth_reached_ = true; + return true; + } + return false; } } // namespace quic diff --git a/gquiche/quic/core/congestion_control/bbr2_misc.h b/gquiche/quic/core/congestion_control/bbr2_misc.h index d04d5e23..c2868b13 100644 --- a/gquiche/quic/core/congestion_control/bbr2_misc.h +++ b/gquiche/quic/core/congestion_control/bbr2_misc.h @@ -9,6 +9,7 @@ #include #include "gquiche/quic/core/congestion_control/bandwidth_sampler.h" +#include "gquiche/quic/core/congestion_control/send_algorithm_interface.h" #include "gquiche/quic/core/congestion_control/windowed_filter.h" #include "gquiche/quic/core/quic_bandwidth.h" #include "gquiche/quic/core/quic_packet_number.h" @@ -92,7 +93,15 @@ struct QUIC_EXPORT_PRIVATE Bbr2Params { // If true, always exit STARTUP on loss, even if bandwidth exceeds threshold. // If false, exit STARTUP on loss only if bandwidth is below threshold. - bool always_exit_startup_on_excess_loss = true; + bool always_exit_startup_on_excess_loss = false; + + // If true, include extra acked during STARTUP and proactively reduce extra + // acked when bandwidth increases. + bool startup_include_extra_acked = false; + + // If true, exit STARTUP if bytes in flight has not gone below 2 * BDP at + // any point in the last round. + bool exit_startup_on_persistent_queue = false; /* * DRAIN parameters. @@ -130,6 +139,11 @@ struct QUIC_EXPORT_PRIVATE Bbr2Params { // Multiplier to get target inflight (as multiple of BDP) for PROBE_UP phase. float probe_bw_probe_inflight_gain = 1.25; + // When attempting to grow inflight_hi in PROBE_UP, check whether we are cwnd + // limited before the current aggregation epoch, instead of before the current + // ack event. + bool probe_bw_check_cwnd_limited_before_aggregation_epoch = false; + // Pacing gains. float probe_bw_probe_up_pacing_gain = 1.25; float probe_bw_probe_down_pacing_gain = 0.75; @@ -137,6 +151,13 @@ struct QUIC_EXPORT_PRIVATE Bbr2Params { float probe_bw_cwnd_gain = 2.0; + /* + * PROBE_UP parameters. + */ + bool probe_up_includes_acks_after_cwnd_limited = false; + bool probe_up_dont_exit_if_no_queue_ = false; + bool probe_up_ignore_inflight_hi = false; + /* * PROBE_RTT parameters. */ @@ -176,10 +197,6 @@ struct QUIC_EXPORT_PRIVATE Bbr2Params { // Can be disabled by connection option 'B2RP'. bool avoid_unnecessary_probe_rtt = true; - // Can be disabled by connection option 'B2CL'. - bool avoid_too_low_probe_bw_cwnd = - GetQuicReloadableFlag(quic_bbr2_avoid_too_low_probe_bw_cwnd); - // Can be enabled by connection option 'B2LO'. bool ignore_inflight_lo = false; @@ -299,12 +316,6 @@ struct QUIC_EXPORT_PRIVATE Bbr2CongestionEvent { // Whether acked_packets indicates the end of a round trip. bool end_of_round_trip = false; - // TODO(wub): After deprecating --quic_one_bw_sample_per_ack_event, use - // last_packet_send_state.is_app_limited instead of this field. - // Whether the last bandwidth sample from acked_packets is app limited. - // false if acked_packets is empty. - bool last_sample_is_app_limited = false; - // When the event happened, whether the sender is probing for bandwidth. bool is_probing_for_bandwidth = false; @@ -313,6 +324,8 @@ struct QUIC_EXPORT_PRIVATE Bbr2CongestionEvent { QuicTime::Delta sample_min_rtt = QuicTime::Delta::Infinite(); // Maximum bandwidth of all bandwidth samples from acked_packets. + // This sample may be app-limited, and will be Zero() if there are no newly + // acknowledged inflight packets. QuicBandwidth sample_max_bandwidth = QuicBandwidth::Zero(); // The send state of the largest packet in acked_packets, unless it is empty. @@ -388,6 +401,15 @@ class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { return bandwidth_sampler_.max_ack_height(); } + // 2 packets. Used to indicate the typical number of bytes ACKed at once. + QuicByteCount QueueingThresholdExtraBytes() const { + return 2 * kDefaultTCPMSS; + } + + bool cwnd_limited_before_aggregation_epoch() const { + return cwnd_limited_before_aggregation_epoch_; + } + void EnableOverestimateAvoidance() { bandwidth_sampler_.EnableOverestimateAvoidance(); } @@ -404,6 +426,22 @@ class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { return bandwidth_sampler_.num_ack_aggregation_epochs(); } + void SetStartNewAggregationEpochAfterFullRound(bool value) { + bandwidth_sampler_.SetStartNewAggregationEpochAfterFullRound(value); + } + + void SetLimitMaxAckHeightTrackerBySendRate(bool value) { + bandwidth_sampler_.SetLimitMaxAckHeightTrackerBySendRate(value); + } + + void SetMaxAckHeightTrackerWindowLength(QuicRoundTripCount value) { + bandwidth_sampler_.SetMaxAckHeightTrackerWindowLength(value); + } + + void SetReduceExtraAckedOnBandwidthIncrease(bool value) { + bandwidth_sampler_.SetReduceExtraAckedOnBandwidthIncrease(value); + } + bool MaybeExpireMinRtt(const Bbr2CongestionEvent& congestion_event); QuicBandwidth BandwidthEstimate() const { @@ -414,30 +452,21 @@ class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { return round_trip_counter_.Count(); } - bool IsCongestionWindowLimited( - const Bbr2CongestionEvent& congestion_event) const; - // Return true if the number of loss events exceeds max_loss_events and // fraction of bytes lost exceed the loss threshold. bool IsInflightTooHigh(const Bbr2CongestionEvent& congestion_event, int64_t max_loss_events) const; - enum BandwidthGrowth { - APP_LIMITED = 0, - NO_GROWTH = 1, - GROWTH = 2, - EXIT = 3, // Too many rounds without bandwidth growth. - }; - // Check bandwidth growth in the past round. Must be called at the end of a - // round. - // Return APP_LIMITED if the bandwidth sample was app-limited. - // Return GROWTH if the bandwidth grew as expected. - // Return NO_GROWTH if the bandwidth didn't increase enough. - // Return TOO_MANY_ROUNDS_WITH_NO_GROWTH if enough rounds have elapsed without - // growth, also sets |full_bandwidth_reached_| to true. - BandwidthGrowth CheckBandwidthGrowth( - const Bbr2CongestionEvent& congestion_event); + // round. Returns true if there was sufficient bandwidth growth and false + // otherwise. If it's been too many rounds without growth, also sets + // |full_bandwidth_reached_| to true. + bool HasBandwidthGrowth(const Bbr2CongestionEvent& congestion_event); + + // Returns true if the minimum bytes in flight during the round is greater + // than the BDP * |bdp_gain|. + bool CheckPersistentQueue(const Bbr2CongestionEvent& congestion_event, + float bdp_gain); QuicPacketNumber last_sent_packet() const { return round_trip_counter_.last_sent_packet(); @@ -461,6 +490,10 @@ class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { return max_bytes_delivered_in_round_; } + QuicByteCount min_bytes_in_flight_in_round() const { + return min_bytes_in_flight_in_round_; + } + QuicPacketNumber end_of_app_limited_phase() const { return bandwidth_sampler_.end_of_app_limited_phase(); } @@ -531,6 +564,9 @@ class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { // congestion event) was sent and acked, respectively. QuicByteCount max_bytes_delivered_in_round_ = 0; + // The minimum bytes in flight during this round. + QuicByteCount min_bytes_in_flight_in_round_ = 0; + // Max bandwidth in the current round. Updated once per congestion event. QuicBandwidth bandwidth_latest_ = QuicBandwidth::Zero(); // Max bandwidth of recent rounds. Updated once per round. @@ -548,6 +584,10 @@ class QUIC_EXPORT_PRIVATE Bbr2NetworkModel { float cwnd_gain_; float pacing_gain_; + // Whether we are cwnd limited prior to the start of the current aggregation + // epoch. + bool cwnd_limited_before_aggregation_epoch_ = false; + // STARTUP-centric fields which experimentally used by PROBE_UP. bool full_bandwidth_reached_ = false; QuicBandwidth full_bandwidth_baseline_ = QuicBandwidth::Zero(); diff --git a/gquiche/quic/core/congestion_control/bbr2_probe_bw.cc b/gquiche/quic/core/congestion_control/bbr2_probe_bw.cc index 1000ba81..00656a09 100644 --- a/gquiche/quic/core/congestion_control/bbr2_probe_bw.cc +++ b/gquiche/quic/core/congestion_control/bbr2_probe_bw.cc @@ -36,8 +36,7 @@ void Bbr2ProbeBwMode::Enter(QuicTime now, } Bbr2Mode Bbr2ProbeBwMode::OnCongestionEvent( - QuicByteCount prior_in_flight, - QuicTime event_time, + QuicByteCount prior_in_flight, QuicTime event_time, const AckedPacketVector& /*acked_packets*/, const LostPacketVector& /*lost_packets*/, const Bbr2CongestionEvent& congestion_event) { @@ -80,36 +79,17 @@ Bbr2Mode Bbr2ProbeBwMode::OnCongestionEvent( } Limits Bbr2ProbeBwMode::GetCwndLimits() const { - if (!GetQuicReloadableFlag(quic_bbr2_avoid_too_low_probe_bw_cwnd)) { - if (cycle_.phase == CyclePhase::PROBE_CRUISE) { - return NoGreaterThan( - std::min(model_->inflight_lo(), model_->inflight_hi_with_headroom())); - } - + if (cycle_.phase == CyclePhase::PROBE_CRUISE) { return NoGreaterThan( - std::min(model_->inflight_lo(), model_->inflight_hi())); + std::min(model_->inflight_lo(), model_->inflight_hi_with_headroom())); } - - QUIC_RELOADABLE_FLAG_COUNT(quic_bbr2_avoid_too_low_probe_bw_cwnd); - - QuicByteCount upper_limit = - std::min(model_->inflight_lo(), cycle_.phase == CyclePhase::PROBE_CRUISE - ? model_->inflight_hi_with_headroom() - : model_->inflight_hi()); - - if (Params().avoid_too_low_probe_bw_cwnd) { - // Ensure upper_limit is at least BDP + AckHeight. - QuicByteCount bdp_with_ack_height = - model_->BDP(model_->MaxBandwidth()) + model_->MaxAckHeight(); - if (upper_limit < bdp_with_ack_height) { - QUIC_DVLOG(3) << sender_ << " Rasing upper_limit from " << upper_limit - << " to " << bdp_with_ack_height; - QUIC_CODE_COUNT(quic_bbr2_avoid_too_low_probe_bw_cwnd_in_effect); - upper_limit = bdp_with_ack_height; - } + if (Params().probe_up_ignore_inflight_hi && + cycle_.phase == CyclePhase::PROBE_UP) { + // Similar to STARTUP. + return NoGreaterThan(model_->inflight_lo()); } - return NoGreaterThan(upper_limit); + return NoGreaterThan(std::min(model_->inflight_lo(), model_->inflight_hi())); } bool Bbr2ProbeBwMode::IsProbingForBandwidth() const { @@ -135,7 +115,7 @@ void Bbr2ProbeBwMode::UpdateProbeDown( if (cycle_.rounds_in_phase == 1 && congestion_event.end_of_round_trip) { cycle_.is_sample_from_probing = false; - if (!congestion_event.last_sample_is_app_limited) { + if (!congestion_event.last_packet_send_state.is_app_limited) { QUIC_DVLOG(2) << sender_ << " Advancing max bw filter after one round in PROBE_DOWN."; @@ -211,12 +191,20 @@ Bbr2ProbeBwMode::AdaptUpperBoundsResult Bbr2ProbeBwMode::MaybeAdaptUpperBounds( << congestion_event.last_packet_send_state.total_bytes_acked << ")"; } } + // TODO(ianswett): Inflight too high is really checking for loss, not + // inflight. if (model_->IsInflightTooHigh(congestion_event, Params().probe_bw_full_loss_count)) { if (cycle_.is_sample_from_probing) { cycle_.is_sample_from_probing = false; - - if (!send_state.is_app_limited) { + if (!send_state.is_app_limited || + Params().probe_up_dont_exit_if_no_queue_) { + if (send_state.is_app_limited) { + // If there's excess loss or a queue is building, exit even if the + // last sample was app limited. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_no_probe_up_exit_if_no_queue, + 2, 2); + } const QuicByteCount inflight_target = sender_->GetTargetBytesInflight() * (1.0 - Params().beta); if (inflight_at_send >= inflight_target) { @@ -363,12 +351,53 @@ void Bbr2ProbeBwMode::RaiseInflightHighSlope() { void Bbr2ProbeBwMode::ProbeInflightHighUpward( const Bbr2CongestionEvent& congestion_event) { QUICHE_DCHECK_EQ(cycle_.phase, CyclePhase::PROBE_UP); - if (!model_->IsCongestionWindowLimited(congestion_event)) { - QUIC_DVLOG(3) << sender_ - << " Raising inflight_hi early return: Not cwnd limited."; - // Not fully utilizing cwnd, so can't safely grow. + if (Params().probe_up_ignore_inflight_hi) { + // When inflight_hi is disabled in PROBE_UP, it increases when + // the number of bytes delivered in a round is larger inflight_hi. return; } + if (Params().probe_bw_check_cwnd_limited_before_aggregation_epoch) { + if (!model_->cwnd_limited_before_aggregation_epoch()) { + QUIC_DVLOG(3) << sender_ + << " Raising inflight_hi early return: Not cwnd limited " + "before aggregation epoch."; + // Not fully utilizing cwnd, so can't safely grow. + return; + } + } else if (Params().probe_up_includes_acks_after_cwnd_limited) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_bbr2_add_bytes_acked_after_inflight_hi_limited); + // Don't continue adding bytes to probe_up_acked if the sender was not + // app-limited after being inflight_hi limited at least once. + if (!cycle_.probe_up_app_limited_since_inflight_hi_limited_ || + congestion_event.last_packet_send_state.is_app_limited) { + cycle_.probe_up_app_limited_since_inflight_hi_limited_ = false; + if (congestion_event.prior_bytes_in_flight < + congestion_event.prior_cwnd) { + QUIC_DVLOG(3) << sender_ + << " Raising inflight_hi early return: Not cwnd limited."; + // Not fully utilizing cwnd, so can't safely grow. + return; + } + + if (congestion_event.prior_cwnd < model_->inflight_hi()) { + QUIC_DVLOG(3) + << sender_ + << " Raising inflight_hi early return: inflight_hi not fully used."; + // Not fully using inflight_hi, so don't grow it. + return; + } + } + // Start a new period of adding bytes_acked, because inflight_hi limited. + cycle_.probe_up_app_limited_since_inflight_hi_limited_ = true; + } else { + if (congestion_event.prior_bytes_in_flight < congestion_event.prior_cwnd) { + QUIC_DVLOG(3) << sender_ + << " Raising inflight_hi early return: Not cwnd limited."; + // Not fully utilizing cwnd, so can't safely grow. + return; + } + } if (congestion_event.prior_cwnd < model_->inflight_hi()) { QUIC_DVLOG(3) @@ -455,25 +484,33 @@ void Bbr2ProbeBwMode::UpdateProbeUp( // TCP uses min_rtt instead of a full round: // HasPhaseLasted(model_->MinRtt(), congestion_event) } else if (cycle_.rounds_in_phase > 0) { - const QuicByteCount bdp = model_->BDP(); - QuicByteCount queuing_threshold_extra_bytes = 2 * kDefaultTCPMSS; - if (Params().add_ack_height_to_queueing_threshold) { - queuing_threshold_extra_bytes += model_->MaxAckHeight(); + if (Params().probe_up_dont_exit_if_no_queue_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_no_probe_up_exit_if_no_queue, 1, + 2); + is_queuing = congestion_event.end_of_round_trip && + model_->CheckPersistentQueue( + congestion_event, Params().probe_bw_probe_inflight_gain); + } else { + QuicByteCount queuing_threshold_extra_bytes = + model_->QueueingThresholdExtraBytes(); + if (Params().add_ack_height_to_queueing_threshold) { + queuing_threshold_extra_bytes += model_->MaxAckHeight(); + } + QuicByteCount queuing_threshold = + (Params().probe_bw_probe_inflight_gain * model_->BDP()) + + queuing_threshold_extra_bytes; + + is_queuing = congestion_event.bytes_in_flight >= queuing_threshold; + + QUIC_DVLOG(3) << sender_ + << " Checking if building up a queue. prior_in_flight:" + << prior_in_flight + << ", post_in_flight:" << congestion_event.bytes_in_flight + << ", threshold:" << queuing_threshold + << ", is_queuing:" << is_queuing + << ", max_bw:" << model_->MaxBandwidth() + << ", min_rtt:" << model_->MinRtt(); } - QuicByteCount queuing_threshold = - (Params().probe_bw_probe_inflight_gain * bdp) + - queuing_threshold_extra_bytes; - - is_queuing = congestion_event.bytes_in_flight >= queuing_threshold; - - QUIC_DVLOG(3) << sender_ - << " Checking if building up a queue. prior_in_flight:" - << prior_in_flight - << ", post_in_flight:" << congestion_event.bytes_in_flight - << ", threshold:" << queuing_threshold - << ", is_queuing:" << is_queuing - << ", max_bw:" << model_->MaxBandwidth() - << ", min_rtt:" << model_->MinRtt(); } if (is_risky || is_queuing) { @@ -483,8 +520,7 @@ void Bbr2ProbeBwMode::UpdateProbeUp( } void Bbr2ProbeBwMode::EnterProbeDown(bool probed_too_high, - bool stopped_risky_probe, - QuicTime now) { + bool stopped_risky_probe, QuicTime now) { QUIC_DVLOG(2) << sender_ << " Phase change: " << cycle_.phase << " ==> " << CyclePhase::PROBE_DOWN << " after " << now - cycle_.phase_start_time << ", or " @@ -500,6 +536,13 @@ void Bbr2ProbeBwMode::EnterProbeDown(bool probed_too_high, cycle_.rounds_in_phase = 0; cycle_.phase_start_time = now; ++sender_->connection_stats_->bbr_num_cycles; + if (Params().bw_lo_mode_ != Bbr2Params::QuicBandwidthLoMode::DEFAULT) { + // Clear bandwidth lo if it was set in PROBE_UP, because losses in PROBE_UP + // should not permanently change bandwidth_lo. + // It's possible for bandwidth_lo to be set during REFILL, but if that was + // a valid value, it'll quickly be rediscovered. + model_->clear_bandwidth_lo(); + } // Pick probe wait time. cycle_.rounds_since_probe = @@ -510,6 +553,7 @@ void Bbr2ProbeBwMode::EnterProbeDown(bool probed_too_high, Params().probe_bw_probe_max_rand_duration.ToMicroseconds())); cycle_.probe_up_bytes = std::numeric_limits::max(); + cycle_.probe_up_app_limited_since_inflight_hi_limited_ = false; cycle_.has_advanced_max_bw = false; model_->RestartRoundEarly(); } @@ -617,9 +661,7 @@ std::ostream& operator<<(std::ostream& os, return os; } -const Bbr2Params& Bbr2ProbeBwMode::Params() const { - return sender_->Params(); -} +const Bbr2Params& Bbr2ProbeBwMode::Params() const { return sender_->Params(); } float Bbr2ProbeBwMode::PacingGainForPhase( Bbr2ProbeBwMode::CyclePhase phase) const { diff --git a/gquiche/quic/core/congestion_control/bbr2_probe_bw.h b/gquiche/quic/core/congestion_control/bbr2_probe_bw.h index eb67c311..67b5f6ee 100644 --- a/gquiche/quic/core/congestion_control/bbr2_probe_bw.h +++ b/gquiche/quic/core/congestion_control/bbr2_probe_bw.h @@ -118,6 +118,7 @@ class QUIC_EXPORT_PRIVATE Bbr2ProbeBwMode final : public Bbr2ModeBase { uint64_t probe_up_rounds = 0; QuicByteCount probe_up_bytes = std::numeric_limits::max(); QuicByteCount probe_up_acked = 0; + bool probe_up_app_limited_since_inflight_hi_limited_ = false; // Whether max bandwidth filter window has advanced in this cycle. It is // advanced once per cycle. bool has_advanced_max_bw = false; diff --git a/gquiche/quic/core/congestion_control/bbr2_sender.cc b/gquiche/quic/core/congestion_control/bbr2_sender.cc index 84a21215..bd73b9ed 100644 --- a/gquiche/quic/core/congestion_control/bbr2_sender.cc +++ b/gquiche/quic/core/congestion_control/bbr2_sender.cc @@ -16,6 +16,7 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/common/print_elements.h" namespace quic { @@ -101,10 +102,6 @@ void Bbr2Sender::SetFromConfig(const QuicConfig& config, if (config.HasClientRequestedIndependentOption(kB2RP, perspective)) { params_.avoid_unnecessary_probe_rtt = false; } - if (GetQuicReloadableFlag(quic_bbr2_avoid_too_low_probe_bw_cwnd) && - config.HasClientRequestedIndependentOption(kB2CL, perspective)) { - params_.avoid_too_low_probe_bw_cwnd = false; - } if (config.HasClientRequestedIndependentOption(k1RTT, perspective)) { params_.startup_full_bw_rounds = 1; } @@ -123,6 +120,16 @@ void Bbr2Sender::SetFromConfig(const QuicConfig& config, void Bbr2Sender::ApplyConnectionOptions( const QuicTagVector& connection_options) { + if (GetQuicReloadableFlag(quic_bbr2_extra_acked_window) && + ContainsQuicTag(connection_options, kBBR4)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_extra_acked_window, 1, 2); + model_.SetMaxAckHeightTrackerWindowLength(20); + } + if (GetQuicReloadableFlag(quic_bbr2_extra_acked_window) && + ContainsQuicTag(connection_options, kBBR5)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_extra_acked_window, 2, 2); + model_.SetMaxAckHeightTrackerWindowLength(40); + } if (ContainsQuicTag(connection_options, kBBQ2)) { params_.startup_cwnd_gain = 2.885; params_.drain_cwnd_gain = 2.885; @@ -132,7 +139,7 @@ void Bbr2Sender::ApplyConnectionOptions( params_.ignore_inflight_lo = true; } if (ContainsQuicTag(connection_options, kB2NE)) { - params_.always_exit_startup_on_excess_loss = false; + params_.always_exit_startup_on_excess_loss = true; } if (ContainsQuicTag(connection_options, kB2SL)) { params_.startup_loss_exit_use_max_delivered_for_inflight_hi = false; @@ -161,6 +168,53 @@ void Bbr2Sender::ApplyConnectionOptions( if (ContainsQuicTag(connection_options, kBBQ9)) { params_.bw_lo_mode_ = Bbr2Params::QuicBandwidthLoMode::CWND_REDUCTION; } + if (ContainsQuicTag(connection_options, kB201)) { + params_.probe_bw_check_cwnd_limited_before_aggregation_epoch = true; + } + if (GetQuicReloadableFlag(quic_bbr2_no_probe_up_exit_if_no_queue) && + ContainsQuicTag(connection_options, kB202)) { + params_.probe_up_dont_exit_if_no_queue_ = true; + } + if (GetQuicReloadableFlag(quic_bbr2_ignore_inflight_hi_in_probe_up) && + ContainsQuicTag(connection_options, kB203)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_bbr2_ignore_inflight_hi_in_probe_up); + params_.probe_up_ignore_inflight_hi = true; + } + if (GetQuicReloadableFlag(quic_bbr2_startup_extra_acked) && + ContainsQuicTag(connection_options, kB204)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_startup_extra_acked, 1, 2); + model_.SetReduceExtraAckedOnBandwidthIncrease(true); + } + if (GetQuicReloadableFlag(quic_bbr2_startup_extra_acked) && + ContainsQuicTag(connection_options, kB205)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_bbr2_startup_extra_acked, 2, 2); + params_.startup_include_extra_acked = true; + } + if (GetQuicReloadableFlag(quic_bbr2_exit_startup_on_persistent_queue2) && + ContainsQuicTag(connection_options, kB207)) { + params_.exit_startup_on_persistent_queue = true; + } + + if (ContainsQuicTag(connection_options, kBBRA)) { + model_.SetStartNewAggregationEpochAfterFullRound(true); + } + if (GetQuicReloadableFlag(quic_bbr_use_send_rate_in_max_ack_height_tracker) && + ContainsQuicTag(connection_options, kBBRB)) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_bbr_use_send_rate_in_max_ack_height_tracker, 2, 2); + model_.SetLimitMaxAckHeightTrackerBySendRate(true); + } + if (GetQuicReloadableFlag( + quic_bbr2_add_bytes_acked_after_inflight_hi_limited) && + ContainsQuicTag(connection_options, kBBQ0)) { + params_.probe_up_includes_acks_after_cwnd_limited = true; + } + + if (GetQuicReloadableFlag(quic_bbr2_startup_probe_up_loss_events) && + ContainsQuicTag(connection_options, kB206)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_bbr2_startup_probe_up_loss_events); + params_.startup_full_loss_count = params_.probe_bw_full_loss_count; + } } Limits Bbr2Sender::GetCwndLimitsByMode() const { @@ -280,14 +334,16 @@ void Bbr2Sender::OnCongestionEvent(bool /*rtt_updated*/, model_.OnCongestionEventFinish(unacked_packets_->GetLeastUnacked(), congestion_event); - last_sample_is_app_limited_ = congestion_event.last_sample_is_app_limited; + last_sample_is_app_limited_ = + congestion_event.last_packet_send_state.is_app_limited; if (congestion_event.bytes_in_flight == 0 && params().avoid_unnecessary_probe_rtt) { OnEnterQuiescence(event_time); } QUIC_DVLOG(3) - << this << " END CongestionEvent(acked:" << acked_packets + << this + << " END CongestionEvent(acked:" << quiche::PrintElements(acked_packets) << ", lost:" << lost_packets.size() << ") " << ", Mode:" << mode_ << ", RttCount:" << model_.RoundTripCount() << ", BytesInFlight:" << congestion_event.bytes_in_flight @@ -346,7 +402,7 @@ void Bbr2Sender::UpdateCongestionWindow(QuicByteCount bytes_acked) { QuicByteCount target_cwnd = GetTargetCongestionWindow(model_.cwnd_gain()); const QuicByteCount prior_cwnd = cwnd_; - if (model_.full_bandwidth_reached()) { + if (model_.full_bandwidth_reached() || Params().startup_include_extra_acked) { target_cwnd += model_.MaxAckHeight(); cwnd_ = std::min(prior_cwnd + bytes_acked, target_cwnd); } else if (prior_cwnd < target_cwnd || prior_cwnd < 2 * initial_cwnd_) { @@ -475,6 +531,10 @@ Bbr2Sender::DebugState Bbr2Sender::ExportDebugState() const { s.inflight_lo = model_.inflight_lo(); s.max_ack_height = model_.MaxAckHeight(); s.min_rtt = model_.MinRtt(); + s.latest_rtt = rtt_stats_->latest_rtt(); + s.mean_deviation = rtt_stats_->mean_deviation(); + s.smoothed_rtt = rtt_stats_->smoothed_rtt(); + s.min_rtt_timestamp = model_.MinRttTimestamp(); s.congestion_window = cwnd_; s.pacing_rate = pacing_rate_; diff --git a/gquiche/quic/core/congestion_control/bbr2_sender.h b/gquiche/quic/core/congestion_control/bbr2_sender.h index 9e1f5d11..d1b1d3b6 100644 --- a/gquiche/quic/core/congestion_control/bbr2_sender.h +++ b/gquiche/quic/core/congestion_control/bbr2_sender.h @@ -123,6 +123,10 @@ class QUIC_EXPORT_PRIVATE Bbr2Sender final : public SendAlgorithmInterface { QuicByteCount inflight_lo; QuicByteCount max_ack_height; QuicTime::Delta min_rtt = QuicTime::Delta::Zero(); + QuicTime::Delta latest_rtt = QuicTime::Delta::Zero(); + QuicTime::Delta smoothed_rtt = QuicTime::Delta::Zero(); + QuicTime::Delta mean_deviation = QuicTime::Delta::Zero(); + QuicTime min_rtt_timestamp = QuicTime::Zero(); QuicByteCount congestion_window; QuicBandwidth pacing_rate = QuicBandwidth::Zero(); diff --git a/gquiche/quic/core/congestion_control/bbr2_simulator_test.cc b/gquiche/quic/core/congestion_control/bbr2_simulator_test.cc index 86b363cf..b4b6a7de 100644 --- a/gquiche/quic/core/congestion_control/bbr2_simulator_test.cc +++ b/gquiche/quic/core/congestion_control/bbr2_simulator_test.cc @@ -202,6 +202,8 @@ class Bbr2DefaultTopologyTest : public Bbr2SimulatorTest { GetQuicFlag(FLAGS_quic_max_congestion_window), &random_, QuicConnectionPeer::GetStats(endpoint->connection()), old_sender); QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); + const int kTestMaxPacketSize = 1350; + endpoint->connection()->SetMaxPacketLength(kTestMaxPacketSize); endpoint->RecordTrace(); return sender; } @@ -308,11 +310,15 @@ class Bbr2DefaultTopologyTest : public Bbr2SimulatorTest { } void SetConnectionOption(QuicTag option) { + SetConnectionOption(std::move(option), sender_); + } + + void SetConnectionOption(QuicTag option, Bbr2Sender* sender) { QuicConfig config; QuicTagVector options; options.push_back(option); QuicConfigPeer::SetReceivedConnectionOptions(&config, options); - sender_->SetFromConfig(config, Perspective::IS_SERVER); + sender->SetFromConfig(config, Perspective::IS_SERVER); } bool Bbr2ModeIsOneOf(const std::vector& expected_modes) const { @@ -381,9 +387,74 @@ TEST_F(Bbr2DefaultTopologyTest, NormalStartup) { 3u, sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); EXPECT_EQ(0u, sender_connection_stats().packets_lost); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); } +TEST_F(Bbr2DefaultTopologyTest, NormalStartupB207) { + SetQuicReloadableFlag(quic_bbr2_exit_startup_on_persistent_queue2, true); + SetConnectionOption(kB207); + DefaultTopologyParams params; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(1u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 1u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); +} + +// Add extra_acked to CWND in STARTUP and exit STARTUP on a persistent queue. +TEST_F(Bbr2DefaultTopologyTest, NormalStartupB207andB205) { + SetQuicReloadableFlag(quic_bbr2_startup_extra_acked, true); + SetQuicReloadableFlag(quic_bbr2_exit_startup_on_persistent_queue2, true); + SetConnectionOption(kB205); + SetConnectionOption(kB207); + DefaultTopologyParams params; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(1u, sender_->ExportDebugState().round_trip_count - max_bw_round); + EXPECT_EQ( + 2u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + EXPECT_EQ(0u, sender_connection_stats().packets_lost); +} + // Test a simple long data transfer in the default setup. TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer) { DefaultTopologyParams params; @@ -417,8 +488,8 @@ TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer) { EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->smoothed_rtt(), 1.0f); } -TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB2NE) { - SetConnectionOption(kB2NE); +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB2RC) { + SetConnectionOption(kB2RC); DefaultTopologyParams params; CreateNetwork(params); @@ -436,8 +507,108 @@ TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB2NE) { EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); } -TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB2RC) { - SetConnectionOption(kB2RC); +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB201) { + SetConnectionOption(kB201); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB206) { + SetQuicReloadableFlag(quic_bbr2_startup_probe_up_loss_events, true); + SetConnectionOption(kB206); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferB207) { + SetQuicReloadableFlag(quic_bbr2_exit_startup_on_persistent_queue2, true); + SetConnectionOption(kB207); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBRB) { + SetQuicReloadableFlag(quic_bbr_use_send_rate_in_max_ack_height_tracker, true); + SetConnectionOption(kBBRB); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBR4) { + SetQuicReloadableFlag(quic_bbr2_extra_acked_window, true); + SetConnectionOption(kBBR4); + DefaultTopologyParams params; + CreateNetwork(params); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.01f); + + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + // The margin here is high, because the aggregation greatly increases + // smoothed rtt. + EXPECT_GE(params.RTT() * 4, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransferBBR5) { + SetQuicReloadableFlag(quic_bbr2_extra_acked_window, true); + SetConnectionOption(kBBR5); DefaultTopologyParams params; CreateNetwork(params); @@ -496,7 +667,37 @@ TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer2RTTAggregationBytes) { EXPECT_APPROX_EQ(params.BottleneckBandwidth(), sender_->ExportDebugState().bandwidth_hi, 0.01f); - EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + if (GetQuicReloadableFlag(quic_fix_pacing_sender_bursts)) { + EXPECT_EQ(sender_loss_rate_in_packets(), 0); + } else { + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + } + // The margin here is high, because both link level aggregation and ack + // decimation can greatly increase smoothed rtt. + EXPECT_GE(params.RTT() * 5, rtt_stats()->smoothed_rtt()); + EXPECT_APPROX_EQ(params.RTT(), rtt_stats()->min_rtt(), 0.2f); +} + +TEST_F(Bbr2DefaultTopologyTest, SimpleTransfer2RTTAggregationBytesB201) { + SetConnectionOption(kB201); + DefaultTopologyParams params; + CreateNetwork(params); + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Transfer 12MB. + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(35)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + + // TODO(wub): Tighten the error bound once BSAO is default enabled. + EXPECT_APPROX_EQ(params.BottleneckBandwidth(), + sender_->ExportDebugState().bandwidth_hi, 0.5f); + + if (GetQuicReloadableFlag(quic_fix_pacing_sender_bursts)) { + EXPECT_LE(sender_loss_rate_in_packets(), 0.01); + } else { + EXPECT_LE(sender_loss_rate_in_packets(), 0.05); + } // The margin here is high, because both link level aggregation and ack // decimation can greatly increase smoothed rtt. EXPECT_GE(params.RTT() * 5, rtt_stats()->smoothed_rtt()); @@ -577,6 +778,375 @@ TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncrease)) { [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, QuicTime::Delta::FromSeconds(50)); EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBQ0 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseBBQ0)) { + SetQuicReloadableFlag(quic_bbr2_add_bytes_acked_after_inflight_hi_limited, + true); + SetConnectionOption(kBBQ0); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with BBQ0 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseBBQ0Aggregation)) { + SetQuicReloadableFlag(quic_bbr2_add_bytes_acked_after_inflight_hi_limited, + true); + SetConnectionOption(kBBQ0); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + // TODO(ianswett) Make these bound tighter once overestimation is reduced. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.6f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.90f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B202 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB202)) { + SetQuicReloadableFlag(quic_bbr2_no_probe_up_exit_if_no_queue, true); + SetConnectionOption(kB202); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.1f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B202 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB202Aggregation)) { + SetQuicReloadableFlag(quic_bbr2_no_probe_up_exit_if_no_queue, true); + SetConnectionOption(kB202); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.6f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.92f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B203 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB203)) { + SetQuicReloadableFlag(quic_bbr2_ignore_inflight_hi_in_probe_up, true); + SetConnectionOption(kB203); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.30); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B203 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB203Aggregation)) { + SetQuicReloadableFlag(quic_bbr2_ignore_inflight_hi_in_probe_up, true); + SetConnectionOption(kB203); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.60f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.91f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B204 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB204)) { + SetQuicReloadableFlag(quic_bbr2_startup_extra_acked, true); + SetConnectionOption(kB204); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.25); + EXPECT_LE(sender_->ExportDebugState().max_ack_height, 2000u); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.02f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B204 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB204Aggregation)) { + SetQuicReloadableFlag(quic_bbr2_startup_extra_acked, true); + SetConnectionOption(kB204); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, and B204 actually + // is increasing overestimation, which is surprising. + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.60f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.35); + EXPECT_LE(sender_->ExportDebugState().max_ack_height, 10000u); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 10% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.95f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B205 +TEST_F(Bbr2DefaultTopologyTest, QUIC_SLOW_TEST(BandwidthIncreaseB205)) { + SetQuicReloadableFlag(quic_bbr2_startup_extra_acked, true); + SetConnectionOption(kB205); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + sender_endpoint_.AddBytesToTransfer(10 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.1f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.10); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure the full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.1f); +} + +// Test Bbr2's reaction to a 100x bandwidth increase during a transfer with B205 +// in the presence of ACK aggregation. +TEST_F(Bbr2DefaultTopologyTest, + QUIC_SLOW_TEST(BandwidthIncreaseB205Aggregation)) { + SetQuicReloadableFlag(quic_bbr2_startup_extra_acked, true); + SetConnectionOption(kB205); + DefaultTopologyParams params; + params.local_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(15000); + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(100); + CreateNetwork(params); + + // 2 RTTs of aggregation, with a max of 10kb. + EnableAggregation(10 * 1024, 2 * params.RTT()); + + // Reduce the payload to 2MB because 10MB takes too long. + sender_endpoint_.AddBytesToTransfer(2 * 1024 * 1024); + + simulator_.RunFor(QuicTime::Delta::FromSeconds(15)); + EXPECT_TRUE(Bbr2ModeIsOneOf({Bbr2Mode::PROBE_BW, Bbr2Mode::PROBE_RTT})); + QUIC_LOG(INFO) << "Bandwidth increasing at time " << SimulatedNow(); + + // This is much farther off when aggregation is present, + // Ideally BSAO or another option would fix this. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_est, 0.45f); + EXPECT_LE(sender_loss_rate_in_packets(), 0.15); + + // Now increase the bottleneck bandwidth from 100Kbps to 10Mbps. + params.test_link.bandwidth = QuicBandwidth::FromKBitsPerSecond(10000); + TestLink()->set_bandwidth(params.test_link.bandwidth); + + bool simulator_result = simulator_.RunUntilOrTimeout( + [this]() { return sender_endpoint_.bytes_to_transfer() == 0; }, + QuicTime::Delta::FromSeconds(50)); + EXPECT_TRUE(simulator_result); + // Ensure at least 5% of full bandwidth is discovered. + EXPECT_APPROX_EQ(params.test_link.bandwidth, + sender_->ExportDebugState().bandwidth_hi, 0.9f); } // Test the number of losses incurred by the startup phase in a situation when @@ -859,6 +1429,42 @@ TEST_F(Bbr2DefaultTopologyTest, ExitStartupDueToLossB2SL) { EXPECT_APPROX_EQ(sender_->ExportDebugState().inflight_hi, params.BDP(), 0.1f); } +// Verifies that in STARTUP, if we exceed loss threshold in a round, we exit +// STARTUP at the end of the round even if there's enough bandwidth growth. +TEST_F(Bbr2DefaultTopologyTest, ExitStartupDueToLossB2NE) { + // Set up flags such that any loss will be considered "too high". + SetQuicFlag(FLAGS_quic_bbr2_default_startup_full_loss_count, 0); + SetQuicFlag(FLAGS_quic_bbr2_default_loss_threshold, 0.0); + + sender_ = SetupBbr2Sender(&sender_endpoint_, /*old_sender=*/nullptr); + + SetConnectionOption(kB2NE); + DefaultTopologyParams params; + params.switch_queue_capacity_in_bdp = 0.5; + CreateNetwork(params); + + // Run until the full bandwidth is reached and check how many rounds it was. + sender_endpoint_.AddBytesToTransfer(12 * 1024 * 1024); + QuicRoundTripCount max_bw_round = 0; + QuicBandwidth max_bw(QuicBandwidth::Zero()); + bool simulator_result = simulator_.RunUntilOrTimeout( + [this, &max_bw, &max_bw_round]() { + if (max_bw < sender_->ExportDebugState().bandwidth_hi) { + max_bw = sender_->ExportDebugState().bandwidth_hi; + max_bw_round = sender_->ExportDebugState().round_trip_count; + } + return sender_->ExportDebugState().startup.full_bandwidth_reached; + }, + QuicTime::Delta::FromSeconds(5)); + ASSERT_TRUE(simulator_result); + EXPECT_EQ(Bbr2Mode::DRAIN, sender_->ExportDebugState().mode); + EXPECT_EQ(sender_->ExportDebugState().round_trip_count, max_bw_round); + EXPECT_EQ( + 0u, + sender_->ExportDebugState().startup.round_trips_without_bandwidth_growth); + EXPECT_NE(0u, sender_connection_stats().packets_lost); +} + TEST_F(Bbr2DefaultTopologyTest, SenderPoliced) { DefaultTopologyParams params; params.sender_policer_params = TrafficPolicerParams(); @@ -887,7 +1493,10 @@ TEST_F(Bbr2DefaultTopologyTest, StartupStats) { ASSERT_FALSE(sender_->InSlowStart()); const QuicConnectionStats& stats = sender_connection_stats(); - EXPECT_EQ(1u, stats.slowstart_count); + // The test explicitly replaces the default-created send algorithm with the + // one created by the test. slowstart_count increaments every time a BBR + // sender is created. + EXPECT_GE(stats.slowstart_count, 1u); EXPECT_FALSE(stats.slowstart_duration.IsRunning()); EXPECT_THAT(stats.slowstart_duration.GetTotalElapsedTime(), AllOf(Ge(QuicTime::Delta::FromMilliseconds(500)), @@ -1058,7 +1667,6 @@ TEST_F(Bbr2DefaultTopologyTest, ProbeBwAfterQuiescencePostponeMinRttTimestamp) { min_rtt_timestamp_after_idle); } -// Regression test for http://go/switchtobbr2midconnection. TEST_F(Bbr2DefaultTopologyTest, SwitchToBbr2MidConnection) { QuicTime now = QuicTime::Zero(); BbrSender old_sender(sender_connection()->clock()->Now(), @@ -1358,6 +1966,10 @@ class Bbr2MultiSenderTest : public Bbr2SimulatorTest { kDefaultInitialCwndPackets, GetQuicFlag(FLAGS_quic_max_congestion_window), &random_, QuicConnectionPeer::GetStats(endpoint->connection()), nullptr); + // TODO(ianswett): Add dedicated tests for this option until it becomes + // the default behavior. + SetConnectionOption(sender, kBBRA); + QuicConnectionPeer::SetSendAlgorithm(endpoint->connection(), sender); endpoint->RecordTrace(); return sender; diff --git a/gquiche/quic/core/congestion_control/bbr2_startup.cc b/gquiche/quic/core/congestion_control/bbr2_startup.cc index 99bcb2af..7f82bd69 100644 --- a/gquiche/quic/core/congestion_control/bbr2_startup.cc +++ b/gquiche/quic/core/congestion_control/bbr2_startup.cc @@ -8,6 +8,7 @@ #include "gquiche/quic/core/congestion_control/bbr2_sender.h" #include "gquiche/quic/core/quic_bandwidth.h" #include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" @@ -17,12 +18,12 @@ Bbr2StartupMode::Bbr2StartupMode(const Bbr2Sender* sender, Bbr2NetworkModel* model, QuicTime now) : Bbr2ModeBase(sender, model) { - // Clear some startup stats if |sender_->connection_stats_| has been used by - // another sender, which happens e.g. when QuicConnection switch send - // algorithms. - sender_->connection_stats_->slowstart_count = 1; - sender_->connection_stats_->slowstart_duration = QuicTimeAccumulator(); - sender_->connection_stats_->slowstart_duration.Start(now); + // Increment, instead of reset startup stats, so we don't lose data recorded + // before QuicConnection switched send algorithm to BBRv2. + ++sender_->connection_stats_->slowstart_count; + if (!sender_->connection_stats_->slowstart_duration.IsRunning()) { + sender_->connection_stats_->slowstart_duration.Start(now); + } // Enter() is never called for Startup, so the gains needs to be set here. model_->set_pacing_gain(Params().startup_pacing_gain); model_->set_cwnd_gain(Params().startup_cwnd_gain); @@ -46,23 +47,30 @@ Bbr2Mode Bbr2StartupMode::OnCongestionEvent( const AckedPacketVector& /*acked_packets*/, const LostPacketVector& /*lost_packets*/, const Bbr2CongestionEvent& congestion_event) { - if (!model_->full_bandwidth_reached() && congestion_event.end_of_round_trip) { - // TCP BBR always exits upon excessive losses. QUIC BBRv1 does not exits - // upon excessive losses, if enough bandwidth growth is observed. - Bbr2NetworkModel::BandwidthGrowth bw_growth = - model_->CheckBandwidthGrowth(congestion_event); - - if (Params().always_exit_startup_on_excess_loss || - (bw_growth == Bbr2NetworkModel::NO_GROWTH || - bw_growth == Bbr2NetworkModel::EXIT)) { - CheckExcessiveLosses(congestion_event); - } + if (model_->full_bandwidth_reached()) { + QUIC_BUG() << "In STARTUP, but full_bandwidth_reached is true."; + return Bbr2Mode::DRAIN; + } + if (!congestion_event.end_of_round_trip) { + return Bbr2Mode::STARTUP; + } + bool has_bandwidth_growth = model_->HasBandwidthGrowth(congestion_event); + if (Params().exit_startup_on_persistent_queue && !has_bandwidth_growth) { + QUIC_RELOADABLE_FLAG_COUNT(quic_bbr2_exit_startup_on_persistent_queue2); + model_->CheckPersistentQueue(congestion_event, Params().startup_cwnd_gain); + } + // TCP BBR always exits upon excessive losses. QUIC BBRv1 does not exit + // upon excessive losses, if enough bandwidth growth is observed or if the + // sample was app limited. + if (Params().always_exit_startup_on_excess_loss || + (!congestion_event.last_packet_send_state.is_app_limited && + !has_bandwidth_growth)) { + CheckExcessiveLosses(congestion_event); } if (Params().decrease_startup_pacing_at_end_of_round) { QUICHE_DCHECK_GT(model_->pacing_gain(), 0); - if (congestion_event.end_of_round_trip && - !congestion_event.last_sample_is_app_limited) { + if (!congestion_event.last_packet_send_state.is_app_limited) { // Multiply by startup_pacing_gain, so if the bandwidth doubles, // the pacing gain will be the full startup_pacing_gain. if (max_bw_at_round_beginning_ > QuicBandwidth::Zero()) { @@ -112,8 +120,9 @@ void Bbr2StartupMode::CheckExcessiveLosses( new_inflight_hi = model_->max_bytes_delivered_in_round(); } } - QUIC_DVLOG(3) << sender_ << " Exiting STARTUP due to loss. inflight_hi:" - << new_inflight_hi; + QUIC_DVLOG(3) << sender_ << " Exiting STARTUP due to loss at round " + << model_->RoundTripCount() + << ". inflight_hi:" << new_inflight_hi; // TODO(ianswett): Add a shared method to set inflight_hi in the model. model_->set_inflight_hi(new_inflight_hi); model_->set_full_bandwidth_reached(); diff --git a/gquiche/quic/core/congestion_control/bbr_sender.cc b/gquiche/quic/core/congestion_control/bbr_sender.cc index b8211031..85db3d3c 100644 --- a/gquiche/quic/core/congestion_control/bbr_sender.cc +++ b/gquiche/quic/core/congestion_control/bbr_sender.cc @@ -61,6 +61,9 @@ BbrSender::DebugState::DebugState(const BbrSender& sender) bandwidth_at_last_round(sender.bandwidth_at_last_round_), rounds_without_bandwidth_gain(sender.rounds_without_bandwidth_gain_), min_rtt(sender.min_rtt_), + smoothed_rtt(sender.rtt_stats_->smoothed_rtt()), + latest_rtt(sender.rtt_stats_->latest_rtt()), + mean_deviation(sender.rtt_stats_->mean_deviation()), min_rtt_timestamp(sender.min_rtt_timestamp_), recovery_state(sender.recovery_state_), recovery_window(sender.recovery_window_), @@ -278,6 +281,15 @@ void BbrSender::ApplyConnectionOptions( if (ContainsQuicTag(connection_options, kBSAO)) { sampler_.EnableOverestimateAvoidance(); } + if (ContainsQuicTag(connection_options, kBBRA)) { + sampler_.SetStartNewAggregationEpochAfterFullRound(true); + } + if (GetQuicReloadableFlag(quic_bbr_use_send_rate_in_max_ack_height_tracker) && + ContainsQuicTag(connection_options, kBBRB)) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_bbr_use_send_rate_in_max_ack_height_tracker, 1, 2); + sampler_.SetLimitMaxAckHeightTrackerBySendRate(true); + } } void BbrSender::AdjustNetworkParameters(const NetworkParams& params) { diff --git a/gquiche/quic/core/congestion_control/bbr_sender.h b/gquiche/quic/core/congestion_control/bbr_sender.h index 55fa50ab..91031d7d 100644 --- a/gquiche/quic/core/congestion_control/bbr_sender.h +++ b/gquiche/quic/core/congestion_control/bbr_sender.h @@ -81,6 +81,10 @@ class QUIC_EXPORT_PRIVATE BbrSender : public SendAlgorithmInterface { QuicTime::Delta min_rtt; QuicTime min_rtt_timestamp; + QuicTime::Delta latest_rtt; + QuicTime::Delta smoothed_rtt; + QuicTime::Delta mean_deviation; + RecoveryState recovery_state; QuicByteCount recovery_window; diff --git a/gquiche/quic/core/congestion_control/bbr_sender_test.cc b/gquiche/quic/core/congestion_control/bbr_sender_test.cc index 56b3f983..f27fcd17 100644 --- a/gquiche/quic/core/congestion_control/bbr_sender_test.cc +++ b/gquiche/quic/core/congestion_control/bbr_sender_test.cc @@ -10,6 +10,7 @@ #include #include "gquiche/quic/core/congestion_control/rtt_stats.h" +#include "gquiche/quic/core/crypto/crypto_protocol.h" #include "gquiche/quic/core/quic_bandwidth.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_types.h" @@ -109,8 +110,10 @@ class BbrSenderTest : public QuicTest { receiver_multiplexer_("Receiver multiplexer", {&receiver_, &competing_receiver_}) { rtt_stats_ = bbr_sender_.connection()->sent_packet_manager().GetRttStats(); + const int kTestMaxPacketSize = 1350; + bbr_sender_.connection()->SetMaxPacketLength(kTestMaxPacketSize); sender_ = SetupBbrSender(&bbr_sender_); - + SetConnectionOption(kBBRA); clock_ = simulator_.GetClock(); } @@ -322,6 +325,39 @@ TEST_F(BbrSenderTest, SimpleTransfer) { EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->smoothed_rtt(), 0.2f); } +TEST_F(BbrSenderTest, SimpleTransferBBRB) { + SetQuicReloadableFlag(quic_bbr_use_send_rate_in_max_ack_height_tracker, true); + SetConnectionOption(kBBRB); + CreateDefaultSetup(); + + // At startup make sure we are at the default. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + // At startup make sure we can send. + EXPECT_TRUE(sender_->CanSend(0)); + // And that window is un-affected. + EXPECT_EQ(kDefaultWindowTCP, sender_->GetCongestionWindow()); + + // Verify that Sender is in slow start. + EXPECT_TRUE(sender_->InSlowStart()); + + // Verify that pacing rate is based on the initial RTT. + QuicBandwidth expected_pacing_rate = QuicBandwidth::FromBytesAndTimeDelta( + 2.885 * kDefaultWindowTCP, rtt_stats_->initial_rtt()); + EXPECT_APPROX_EQ(expected_pacing_rate.ToBitsPerSecond(), + sender_->PacingRate(0).ToBitsPerSecond(), 0.01f); + + ASSERT_GE(kTestBdp, kDefaultWindowTCP + kDefaultTCPMSS); + + DoSimpleTransfer(12 * 1024 * 1024, QuicTime::Delta::FromSeconds(30)); + EXPECT_EQ(BbrSender::PROBE_BW, sender_->ExportDebugState().mode); + EXPECT_EQ(0u, bbr_sender_.connection()->GetStats().packets_lost); + EXPECT_FALSE(sender_->ExportDebugState().last_sample_is_app_limited); + + // The margin here is quite high, since there exists a possibility that the + // connection just exited high gain cycle. + EXPECT_APPROX_EQ(kTestRtt, rtt_stats_->smoothed_rtt(), 0.2f); +} + // Test a simple transfer in a situation when the buffer is less than BDP. TEST_F(BbrSenderTest, SimpleTransferSmallBuffer) { CreateSmallBufferSetup(); diff --git a/gquiche/quic/core/congestion_control/pacing_sender.cc b/gquiche/quic/core/congestion_control/pacing_sender.cc index a98c919d..3d700ae6 100644 --- a/gquiche/quic/core/congestion_control/pacing_sender.cc +++ b/gquiche/quic/core/congestion_control/pacing_sender.cc @@ -97,11 +97,18 @@ void PacingSender::OnPacketSent( GetQuicFlag(FLAGS_quic_lumpy_pacing_cwnd_fraction)) / kDefaultTCPMSS))); if (sender_->BandwidthEstimate() < - QuicBandwidth::FromKBitsPerSecond(1200)) { + QuicBandwidth::FromKBitsPerSecond( + GetQuicFlag(FLAGS_quic_lumpy_pacing_min_bandwidth_kbps))) { // Below 1.2Mbps, send 1 packet at once, because one full-sized packet // is about 10ms of queueing. lumpy_tokens_ = 1u; } + if (GetQuicReloadableFlag(quic_fix_pacing_sender_bursts) && + (bytes_in_flight + bytes) >= sender_->GetCongestionWindow()) { + QUIC_RELOADABLE_FLAG_COUNT(quic_fix_pacing_sender_bursts); + // Don't add lumpy_tokens if the congestion controller is CWND limited. + lumpy_tokens_ = 1u; + } } --lumpy_tokens_; if (pacing_limited_) { diff --git a/gquiche/quic/core/congestion_control/pacing_sender_test.cc b/gquiche/quic/core/congestion_control/pacing_sender_test.cc index cac12605..08bbafc6 100644 --- a/gquiche/quic/core/congestion_control/pacing_sender_test.cc +++ b/gquiche/quic/core/congestion_control/pacing_sender_test.cc @@ -7,6 +7,7 @@ #include #include +#include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" @@ -75,42 +76,41 @@ class PacingSenderTest : public QuicTest { } void CheckPacketIsSentImmediately(HasRetransmittableData retransmittable_data, - QuicByteCount bytes_in_flight, + QuicByteCount prior_in_flight, bool in_recovery, - bool cwnd_limited, QuicPacketCount cwnd) { // In order for the packet to be sendable, the underlying sender must // permit it to be sent immediately. for (int i = 0; i < 2; ++i) { - EXPECT_CALL(*mock_sender_, CanSend(bytes_in_flight)) + EXPECT_CALL(*mock_sender_, CanSend(prior_in_flight)) .WillOnce(Return(true)); // Verify that the packet can be sent immediately. EXPECT_EQ(zero_time_, - pacing_sender_->TimeUntilSend(clock_.Now(), bytes_in_flight)); + pacing_sender_->TimeUntilSend(clock_.Now(), prior_in_flight)); } // Actually send the packet. - if (bytes_in_flight == 0) { + if (prior_in_flight == 0) { EXPECT_CALL(*mock_sender_, InRecovery()).WillOnce(Return(in_recovery)); } EXPECT_CALL(*mock_sender_, - OnPacketSent(clock_.Now(), bytes_in_flight, packet_number_, + OnPacketSent(clock_.Now(), prior_in_flight, packet_number_, kMaxOutgoingPacketSize, retransmittable_data)); EXPECT_CALL(*mock_sender_, GetCongestionWindow()) - .Times(AtMost(1)) .WillRepeatedly(Return(cwnd * kDefaultTCPMSS)); EXPECT_CALL(*mock_sender_, - CanSend(bytes_in_flight + kMaxOutgoingPacketSize)) + CanSend(prior_in_flight + kMaxOutgoingPacketSize)) .Times(AtMost(1)) - .WillRepeatedly(Return(!cwnd_limited)); - pacing_sender_->OnPacketSent(clock_.Now(), bytes_in_flight, + .WillRepeatedly(Return((prior_in_flight + kMaxOutgoingPacketSize) < + (cwnd * kDefaultTCPMSS))); + pacing_sender_->OnPacketSent(clock_.Now(), prior_in_flight, packet_number_++, kMaxOutgoingPacketSize, retransmittable_data); } void CheckPacketIsSentImmediately() { CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, kBytesInFlight, - false, false, 10); + false, 10); } void CheckPacketIsDelayed(QuicTime::Delta delay) { @@ -242,7 +242,7 @@ TEST_F(PacingSenderTest, InitialBurst) { // Next time TimeUntilSend is called with no bytes in flight, pacing should // allow a packet to be sent, and when it's sent, the tokens are refilled. - CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, false, 10); + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, 10); for (int i = 0; i < kInitialBurstPackets - 1; ++i) { CheckPacketIsSentImmediately(); } @@ -276,7 +276,7 @@ TEST_F(PacingSenderTest, InitialBurstNoRttMeasurement) { // Next time TimeUntilSend is called with no bytes in flight, the tokens // should be refilled and there should be no delay. - CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, false, 10); + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, 10); // Send 10 packets, and verify that they are not paced. for (int i = 0; i < kInitialBurstPackets - 1; ++i) { CheckPacketIsSentImmediately(); @@ -314,7 +314,7 @@ TEST_F(PacingSenderTest, FastSending) { // Next time TimeUntilSend is called with no bytes in flight, the tokens // should be refilled and there should be no delay. - CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, false, 10); + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, false, 10); for (int i = 0; i < kInitialBurstPackets - 1; ++i) { CheckPacketIsSentImmediately(); } @@ -348,10 +348,12 @@ TEST_F(PacingSenderTest, NoBurstEnteringRecovery) { // One packet is sent immediately, because of 1ms pacing granularity. CheckPacketIsSentImmediately(); // Ensure packets are immediately paced. - EXPECT_CALL(*mock_sender_, CanSend(kDefaultTCPMSS)).WillOnce(Return(true)); + EXPECT_CALL(*mock_sender_, CanSend(kMaxOutgoingPacketSize)) + .WillOnce(Return(true)); // Verify the next packet is paced and delayed 2ms due to granularity. - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(2), - pacing_sender_->TimeUntilSend(clock_.Now(), kDefaultTCPMSS)); + EXPECT_EQ( + QuicTime::Delta::FromMilliseconds(2), + pacing_sender_->TimeUntilSend(clock_.Now(), kMaxOutgoingPacketSize)); CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); } @@ -364,7 +366,7 @@ TEST_F(PacingSenderTest, NoBurstInRecovery) { UpdateRtt(); // Ensure only one packet is sent immediately and the rest are paced. - CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, true, false, 10); + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, 0, true, 10); CheckPacketIsSentImmediately(); CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); } @@ -385,8 +387,11 @@ TEST_F(PacingSenderTest, CwndLimited) { // Wake up on time. clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(2)); // After sending packet 3, cwnd is limited. - CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, kBytesInFlight, false, - true, 10); + // This test is slightly odd because bytes_in_flight is calculated using + // kMaxOutgoingPacketSize and CWND is calculated using kDefaultTCPMSS, + // which is 8 bytes larger, so 3 packets can be sent for a CWND of 2. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 2 * kMaxOutgoingPacketSize, false, 2); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); // Verify pacing sender stops making up for lost time after sending packet 3. @@ -438,20 +443,23 @@ TEST_F(PacingSenderTest, LumpyPacingWithInitialBurstToken) { clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3)); CheckPacketIsSentImmediately(); // After sending packet 21, cwnd is limited. - CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, kBytesInFlight, false, - true, 10); + // This test is slightly odd because bytes_in_flight is calculated using + // kMaxOutgoingPacketSize and CWND is calculated using kDefaultTCPMSS, + // which is 8 bytes larger, so 21 packets can be sent for a CWND of 20. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 20 * kMaxOutgoingPacketSize, false, 20); clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); // Suppose cwnd size is 5, so that lumpy size becomes 2. CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, kBytesInFlight, false, - false, 5); + 5); CheckPacketIsSentImmediately(); // Packet 24 will be delayed 2ms. CheckPacketIsDelayed(QuicTime::Delta::FromMilliseconds(2)); } TEST_F(PacingSenderTest, NoLumpyPacingForLowBandwidthFlows) { - // Set lumpy size to be 3, and cwnd faction to 0.5 + // Set lumpy size to be 3, and cwnd fraction to 0.5 SetQuicFlag(FLAGS_quic_lumpy_pacing_size, 3); SetQuicFlag(FLAGS_quic_lumpy_pacing_cwnd_fraction, 0.5f); @@ -476,6 +484,43 @@ TEST_F(PacingSenderTest, NoLumpyPacingForLowBandwidthFlows) { } } +// Regression test for b/184471302 to ensure that ACKs received back-to-back +// don't cause bursts in sending. +TEST_F(PacingSenderTest, NoBurstsForLumpyPacingWithAckAggregation) { + // Configure pacing rate of 1 packet per millisecond. + QuicTime::Delta inter_packet_delay = QuicTime::Delta::FromMilliseconds(1); + InitPacingRate(kInitialBurstPackets, + QuicBandwidth::FromBytesAndTimeDelta(kMaxOutgoingPacketSize, + inter_packet_delay)); + UpdateRtt(); + + // Send kInitialBurstPackets packets, and verify that they are not paced. + for (int i = 0; i < kInitialBurstPackets; ++i) { + CheckPacketIsSentImmediately(); + } + // The last packet of the burst causes the sender to be CWND limited. + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 10 * kMaxOutgoingPacketSize, false, 10); + + if (GetQuicReloadableFlag(quic_fix_pacing_sender_bursts)) { + // The last sent packet made the connection CWND limited, so no lumpy tokens + // should be available. + EXPECT_EQ(0u, pacing_sender_->lumpy_tokens()); + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 10 * kMaxOutgoingPacketSize, false, 10); + EXPECT_EQ(0u, pacing_sender_->lumpy_tokens()); + CheckPacketIsDelayed(2 * inter_packet_delay); + } else { + EXPECT_EQ(1u, pacing_sender_->lumpy_tokens()); + // Repeatedly send single packets to make the sender CWND limited and + // observe that there's no pacing without the fix. + for (int i = 0; i < 10; ++i) { + CheckPacketIsSentImmediately(HAS_RETRANSMITTABLE_DATA, + 10 * kMaxOutgoingPacketSize, false, 10); + } + } +} + TEST_F(PacingSenderTest, IdealNextPacketSendTimeWithLumpyPacing) { // Set lumpy size to be 3, and cwnd faction to 0.5 SetQuicFlag(FLAGS_quic_lumpy_pacing_size, 3); diff --git a/gquiche/quic/core/congestion_control/rtt_stats.cc b/gquiche/quic/core/congestion_control/rtt_stats.cc index e12644e3..3efc58e6 100644 --- a/gquiche/quic/core/congestion_control/rtt_stats.cc +++ b/gquiche/quic/core/congestion_control/rtt_stats.cc @@ -29,8 +29,7 @@ RttStats::RttStats() mean_deviation_(QuicTime::Delta::Zero()), calculate_standard_deviation_(false), initial_rtt_(QuicTime::Delta::FromMilliseconds(kInitialRttMs)), - last_update_time_(QuicTime::Zero()), - ignore_max_ack_delay_(false) {} + last_update_time_(QuicTime::Zero()) {} void RttStats::ExpireSmoothedMetrics() { mean_deviation_ = std::max( @@ -63,17 +62,19 @@ void RttStats::UpdateRtt(QuicTime::Delta send_delta, QuicTime::Delta rtt_sample(send_delta); previous_srtt_ = smoothed_rtt_; - - if (ignore_max_ack_delay_) { - ack_delay = QuicTime::Delta::Zero(); - } // Correct for ack_delay if information received from the peer results in a // an RTT sample at least as large as min_rtt. Otherwise, only use the // send_delta. + // TODO(fayang): consider to ignore rtt_sample if rtt_sample < ack_delay and + // ack_delay is relatively large. if (rtt_sample > ack_delay) { if (rtt_sample - min_rtt_ >= ack_delay) { rtt_sample = rtt_sample - ack_delay; + } else { + QUIC_CODE_COUNT(quic_ack_delay_makes_rtt_sample_smaller_than_min_rtt); } + } else { + QUIC_CODE_COUNT(quic_ack_delay_greater_than_rtt_sample); } latest_rtt_ = rtt_sample; if (calculate_standard_deviation_) { @@ -138,7 +139,6 @@ void RttStats::CloneFrom(const RttStats& stats) { calculate_standard_deviation_ = stats.calculate_standard_deviation_; initial_rtt_ = stats.initial_rtt_; last_update_time_ = stats.last_update_time_; - ignore_max_ack_delay_ = stats.ignore_max_ack_delay_; } } // namespace quic diff --git a/gquiche/quic/core/congestion_control/rtt_stats.h b/gquiche/quic/core/congestion_control/rtt_stats.h index 7db07d9a..53b2f522 100644 --- a/gquiche/quic/core/congestion_control/rtt_stats.h +++ b/gquiche/quic/core/congestion_control/rtt_stats.h @@ -101,12 +101,6 @@ class QUIC_EXPORT_PRIVATE RttStats { QuicTime last_update_time() const { return last_update_time_; } - bool ignore_max_ack_delay() const { return ignore_max_ack_delay_; } - - void set_ignore_max_ack_delay(bool ignore_max_ack_delay) { - ignore_max_ack_delay_ = ignore_max_ack_delay; - } - void EnableStandardDeviationCalculation() { calculate_standard_deviation_ = true; } @@ -130,8 +124,6 @@ class QUIC_EXPORT_PRIVATE RttStats { bool calculate_standard_deviation_; QuicTime::Delta initial_rtt_; QuicTime last_update_time_; - // Whether to ignore the peer's max ack delay. - bool ignore_max_ack_delay_; }; } // namespace quic diff --git a/gquiche/quic/core/congestion_control/rtt_stats_test.cc b/gquiche/quic/core/congestion_control/rtt_stats_test.cc index aa7f31f2..d72c3c4b 100644 --- a/gquiche/quic/core/congestion_control/rtt_stats_test.cc +++ b/gquiche/quic/core/congestion_control/rtt_stats_test.cc @@ -56,34 +56,6 @@ TEST_F(RttStatsTest, SmoothedRtt) { rtt_stats_.smoothed_rtt()); } -TEST_F(RttStatsTest, SmoothedRttIgnoreAckDelay) { - rtt_stats_.set_ignore_max_ack_delay(true); - // Verify that ack_delay is ignored in the first measurement. - rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(300), - QuicTime::Delta::FromMilliseconds(100), - QuicTime::Zero()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.latest_rtt()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.smoothed_rtt()); - // Verify that a plausible ack delay increases the max ack delay. - rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(300), - QuicTime::Delta::FromMilliseconds(100), - QuicTime::Zero()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.latest_rtt()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.smoothed_rtt()); - // Verify that Smoothed RTT includes max ack delay if it's reasonable. - rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(300), - QuicTime::Delta::FromMilliseconds(50), QuicTime::Zero()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.latest_rtt()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), rtt_stats_.smoothed_rtt()); - // Verify that large erroneous ack_delay does not change Smoothed RTT. - rtt_stats_.UpdateRtt(QuicTime::Delta::FromMilliseconds(200), - QuicTime::Delta::FromMilliseconds(300), - QuicTime::Zero()); - EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), rtt_stats_.latest_rtt()); - EXPECT_EQ(QuicTime::Delta::FromMicroseconds(287500), - rtt_stats_.smoothed_rtt()); -} - // Ensure that the potential rounding artifacts in EWMA calculation do not cause // the SRTT to drift too far from the exact value. TEST_F(RttStatsTest, SmoothedRttStability) { diff --git a/gquiche/quic/core/congestion_control/send_algorithm_test.cc b/gquiche/quic/core/congestion_control/send_algorithm_test.cc index 86fa8f1a..deccdf8a 100644 --- a/gquiche/quic/core/congestion_control/send_algorithm_test.cc +++ b/gquiche/quic/core/congestion_control/send_algorithm_test.cc @@ -175,6 +175,8 @@ class SendAlgorithmTest : public QuicTestWithParam { quic_sender_.RecordTrace(); QuicConnectionPeer::SetSendAlgorithm(quic_sender_.connection(), sender_); + const int kTestMaxPacketSize = 1350; + quic_sender_.connection()->SetMaxPacketLength(kTestMaxPacketSize); clock_ = simulator_.GetClock(); simulator_.set_random_generator(&random_); diff --git a/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc b/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc index 1a3409c0..6e02ebec 100644 --- a/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc +++ b/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.cc @@ -56,6 +56,15 @@ TcpCubicSenderBytes::TcpCubicSenderBytes( TcpCubicSenderBytes::~TcpCubicSenderBytes() {} +TcpCubicSenderBytes::DebugState::DebugState(const TcpCubicSenderBytes& sender) + : min_rtt(sender.rtt_stats_->smoothed_rtt()), + latest_rtt(sender.rtt_stats_->latest_rtt()), + smoothed_rtt(sender.rtt_stats_->smoothed_rtt()), + mean_deviation(sender.rtt_stats_->mean_deviation()), + bandwidth_est(sender.BandwidthEstimate()) {} + +TcpCubicSenderBytes::DebugState::DebugState(const DebugState& state) = default; + void TcpCubicSenderBytes::SetFromConfig(const QuicConfig& config, Perspective perspective) { if (perspective == Perspective::IS_SERVER) { @@ -257,6 +266,10 @@ std::string TcpCubicSenderBytes::GetDebugState() const { return ""; } +TcpCubicSenderBytes::DebugState TcpCubicSenderBytes::ExportDebugState() const { + return DebugState(*this); +} + void TcpCubicSenderBytes::OnApplicationLimited( QuicByteCount /*bytes_in_flight*/) {} diff --git a/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h b/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h index 039bd567..56526229 100644 --- a/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h +++ b/gquiche/quic/core/congestion_control/tcp_cubic_sender_bytes.h @@ -42,6 +42,15 @@ class QUIC_EXPORT_PRIVATE TcpCubicSenderBytes : public SendAlgorithmInterface { TcpCubicSenderBytes(const TcpCubicSenderBytes&) = delete; TcpCubicSenderBytes& operator=(const TcpCubicSenderBytes&) = delete; ~TcpCubicSenderBytes() override; + struct QUIC_EXPORT_PRIVATE DebugState { + explicit DebugState(const TcpCubicSenderBytes& sender); + DebugState(const DebugState& state); + QuicTime::Delta min_rtt; + QuicTime::Delta latest_rtt; + QuicTime::Delta smoothed_rtt; + QuicTime::Delta mean_deviation; + QuicBandwidth bandwidth_est; + }; // Start implementation of SendAlgorithmInterface. void SetFromConfig(const QuicConfig& config, @@ -78,7 +87,8 @@ class QUIC_EXPORT_PRIVATE TcpCubicSenderBytes : public SendAlgorithmInterface { void OnApplicationLimited(QuicByteCount bytes_in_flight) override; void PopulateConnectionStats(QuicConnectionStats* /*stats*/) const override {} // End implementation of SendAlgorithmInterface. - + + DebugState ExportDebugState() const; QuicByteCount min_congestion_window() const { return min_congestion_window_; } protected: diff --git a/gquiche/quic/core/congestion_control/windowed_filter.h b/gquiche/quic/core/congestion_control/windowed_filter.h index 13bf2a85..6485ed44 100644 --- a/gquiche/quic/core/congestion_control/windowed_filter.h +++ b/gquiche/quic/core/congestion_control/windowed_filter.h @@ -71,6 +71,7 @@ class QUIC_EXPORT_PRIVATE WindowedFilter { WindowedFilter(TimeDeltaT window_length, T zero_value, TimeT zero_time) : window_length_(window_length), zero_value_(zero_value), + zero_time_(zero_time), estimates_{Sample(zero_value_, zero_time), Sample(zero_value_, zero_time), Sample(zero_value_, zero_time)} {} @@ -138,6 +139,8 @@ class QUIC_EXPORT_PRIVATE WindowedFilter { Sample(new_sample, new_time); } + void Clear() { Reset(zero_value_, zero_time_); } + T GetBest() const { return estimates_[0].sample; } T GetSecondBest() const { return estimates_[1].sample; } T GetThirdBest() const { return estimates_[2].sample; } @@ -152,6 +155,7 @@ class QUIC_EXPORT_PRIVATE WindowedFilter { TimeDeltaT window_length_; // Time length of window. T zero_value_; // Uninitialized value of T. + TimeT zero_time_; // Uninitialized value of TimeT. Sample estimates_[3]; // Best estimate is element 0. }; diff --git a/gquiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc b/gquiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc index ce2d3628..daae810e 100644 --- a/gquiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc +++ b/gquiche/quic/core/crypto/aes_128_gcm_12_decrypter_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc b/gquiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc index 8b430151..0aa6578a 100644 --- a/gquiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc +++ b/gquiche/quic/core/crypto/aes_128_gcm_12_encrypter_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc b/gquiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc index c338491c..ddbc1522 100644 --- a/gquiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc +++ b/gquiche/quic/core/crypto/aes_128_gcm_decrypter_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc b/gquiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc index 70e9eed2..cfdc46b2 100644 --- a/gquiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc +++ b/gquiche/quic/core/crypto/aes_128_gcm_encrypter_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc b/gquiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc index 7bb4d373..812a3b3a 100644 --- a/gquiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc +++ b/gquiche/quic/core/crypto/aes_256_gcm_decrypter_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc b/gquiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc index 7abe203a..0dd6a35a 100644 --- a/gquiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc +++ b/gquiche/quic/core/crypto/aes_256_gcm_encrypter_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/cert_compressor.cc b/gquiche/quic/core/crypto/cert_compressor.cc index 31677841..9f355c51 100644 --- a/gquiche/quic/core/crypto/cert_compressor.cc +++ b/gquiche/quic/core/crypto/cert_compressor.cc @@ -11,6 +11,9 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/quic_utils.h" +#include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" +#include "gquiche/quic/platform/api/quic_flags.h" #include "zlib.h" namespace quic { @@ -215,8 +218,11 @@ std::vector MatchCerts(const std::vector& certs, } } - if (common_sets && common_sets->MatchCert(*i, client_common_set_hashes, - &entry.set_hash, &entry.index)) { + if (GetQuicRestartFlag(quic_no_common_cert_set)) { + QUIC_RESTART_FLAG_COUNT(quic_no_common_cert_set); + } else if (common_sets && + common_sets->MatchCert(*i, client_common_set_hashes, + &entry.set_hash, &entry.index)) { entry.type = CertEntry::COMMON; entries.push_back(entry); continue; @@ -243,6 +249,8 @@ size_t CertEntriesSize(const std::vector& entries) { entries_size += sizeof(uint64_t); break; case CertEntry::COMMON: + QUIC_BUG_IF(unexpected_common_cert_entry_1, + GetQuicRestartFlag(quic_no_common_cert_set)); entries_size += sizeof(uint64_t) + sizeof(uint32_t); break; } @@ -266,6 +274,8 @@ void SerializeCertEntries(uint8_t* out, const std::vector& entries) { out += sizeof(uint64_t); break; case CertEntry::COMMON: + QUIC_BUG_IF(unexpected_common_cert_entry_2, + GetQuicRestartFlag(quic_no_common_cert_set)); // Assumes a little-endian machine. memcpy(out, &i->set_hash, sizeof(i->set_hash)); out += sizeof(i->set_hash); @@ -382,6 +392,10 @@ bool ParseEntries(absl::string_view* in_out, break; } case CertEntry::COMMON: { + if (GetQuicRestartFlag(quic_no_common_cert_set)) { + // Client only. No flag count. + return false; + } if (!common_sets) { return false; } @@ -628,7 +642,10 @@ bool CertCompressor::DecompressChain( uncompressed.remove_prefix(cert_len); break; case CertEntry::CACHED: + break; case CertEntry::COMMON: + QUIC_BUG_IF(unexpected_common_cert_entry_3, + GetQuicRestartFlag(quic_no_common_cert_set)); break; } } diff --git a/gquiche/quic/core/crypto/cert_compressor_test.cc b/gquiche/quic/core/crypto/cert_compressor_test.cc index 0db7fd1a..f91ce0c6 100644 --- a/gquiche/quic/core/crypto/cert_compressor_test.cc +++ b/gquiche/quic/core/crypto/cert_compressor_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/crypto_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { @@ -57,12 +56,18 @@ TEST_F(CertCompressorTest, Common) { absl::string_view(reinterpret_cast(&set_hash), sizeof(set_hash)), absl::string_view(), common_sets.get()); - EXPECT_EQ( - "03" /* common */ - "2a00000000000000" /* set hash 42 */ - "01000000" /* index 1 */ - "00" /* end of list */, - absl::BytesToHexString(compressed)); + if (!GetQuicRestartFlag(quic_no_common_cert_set)) { + EXPECT_EQ( + "03" /* common */ + "2a00000000000000" /* set hash 42 */ + "01000000" /* index 1 */ + "00" /* end of list */, + absl::BytesToHexString(compressed)); + } else { + ASSERT_GE(compressed.size(), 2u); + // 01 is the prefix for a zlib "compressed" cert not common or cached. + EXPECT_EQ("0100", absl::BytesToHexString(compressed.substr(0, 2))); + } std::vector chain2, cached_certs; ASSERT_TRUE(CertCompressor::DecompressChain(compressed, cached_certs, diff --git a/gquiche/quic/core/crypto/certificate_view.cc b/gquiche/quic/core/crypto/certificate_view.cc index cf288ff8..a8a45ade 100644 --- a/gquiche/quic/core/crypto/certificate_view.cc +++ b/gquiche/quic/core/crypto/certificate_view.cc @@ -30,9 +30,9 @@ #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_ip_address.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/platform/api/quiche_time_utils.h" #include "gquiche/common/quiche_data_reader.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace { @@ -586,7 +586,7 @@ std::unique_ptr CertificatePrivateKey::LoadPemFromStream( } std::string CertificatePrivateKey::Sign(absl::string_view input, - uint16_t signature_algorithm) { + uint16_t signature_algorithm) const { if (!ValidForSignatureAlgorithm(signature_algorithm)) { QUIC_BUG(quic_bug_10640_2) << "Mismatch between the requested signature algorithm and the " @@ -626,12 +626,13 @@ std::string CertificatePrivateKey::Sign(absl::string_view input, return output; } -bool CertificatePrivateKey::MatchesPublicKey(const CertificateView& view) { +bool CertificatePrivateKey::MatchesPublicKey( + const CertificateView& view) const { return EVP_PKEY_cmp(view.public_key(), private_key_.get()) == 1; } bool CertificatePrivateKey::ValidForSignatureAlgorithm( - uint16_t signature_algorithm) { + uint16_t signature_algorithm) const { return PublicKeyTypeFromSignatureAlgorithm(signature_algorithm) == PublicKeyTypeFromKey(private_key_.get()); } diff --git a/gquiche/quic/core/crypto/certificate_view.h b/gquiche/quic/core/crypto/certificate_view.h index 17a0e2bf..a214908e 100644 --- a/gquiche/quic/core/crypto/certificate_view.h +++ b/gquiche/quic/core/crypto/certificate_view.h @@ -105,17 +105,17 @@ class QUIC_EXPORT_PRIVATE CertificatePrivateKey { std::istream* input); // |signature_algorithm| is a TLS signature algorithm ID. - std::string Sign(absl::string_view input, uint16_t signature_algorithm); + std::string Sign(absl::string_view input, uint16_t signature_algorithm) const; // Verifies that the private key in question matches the public key of the // certificate |view|. - bool MatchesPublicKey(const CertificateView& view); + bool MatchesPublicKey(const CertificateView& view) const; // Verifies that the private key can be used with the specified TLS signature // algorithm. - bool ValidForSignatureAlgorithm(uint16_t signature_algorithm); + bool ValidForSignatureAlgorithm(uint16_t signature_algorithm) const; - EVP_PKEY* private_key() { return private_key_.get(); } + EVP_PKEY* private_key() const { return private_key_.get(); } private: CertificatePrivateKey() = default; diff --git a/gquiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc b/gquiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc index a8e82a6c..c8b8bb5c 100644 --- a/gquiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc +++ b/gquiche/quic/core/crypto/chacha20_poly1305_decrypter_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc b/gquiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc index 9e6aeea6..11ff4c6d 100644 --- a/gquiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc +++ b/gquiche/quic/core/crypto/chacha20_poly1305_encrypter_test.cc @@ -8,12 +8,12 @@ #include #include "absl/base/macros.h" +#include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/crypto/chacha20_poly1305_decrypter.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc b/gquiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc index da1cff2a..08c5e40c 100644 --- a/gquiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc +++ b/gquiche/quic/core/crypto/chacha20_poly1305_tls_decrypter_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc b/gquiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc index b2f6ba2f..ce803284 100644 --- a/gquiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc +++ b/gquiche/quic/core/crypto/chacha20_poly1305_tls_encrypter_test.cc @@ -14,7 +14,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace { diff --git a/gquiche/quic/core/crypto/client_proof_source.cc b/gquiche/quic/core/crypto/client_proof_source.cc new file mode 100644 index 00000000..d791e975 --- /dev/null +++ b/gquiche/quic/core/crypto/client_proof_source.cc @@ -0,0 +1,62 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/crypto/client_proof_source.h" + +#include "absl/strings/match.h" + +namespace quic { + +bool DefaultClientProofSource::AddCertAndKey( + std::vector server_hostnames, + QuicReferenceCountedPointer chain, + CertificatePrivateKey private_key) { + if (!ValidateCertAndKey(chain, private_key)) { + return false; + } + + auto cert_and_key = + std::make_shared(std::move(chain), std::move(private_key)); + for (const std::string& domain : server_hostnames) { + cert_and_keys_[domain] = cert_and_key; + } + return true; +} + +const ClientProofSource::CertAndKey* DefaultClientProofSource::GetCertAndKey( + absl::string_view hostname) const { + const CertAndKey* result = LookupExact(hostname); + if (result != nullptr || hostname == "*") { + return result; + } + + // Either a full or a wildcard domain lookup failed. In the former case, + // derive the wildcard domain and look it up. + if (hostname.size() > 1 && !absl::StartsWith(hostname, "*.")) { + auto dot_pos = hostname.find('.'); + if (dot_pos != std::string::npos) { + std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos)); + const CertAndKey* result = LookupExact(wildcard); + if (result != nullptr) { + return result; + } + } + } + + // Return default cert, if any. + return LookupExact("*"); +} + +const ClientProofSource::CertAndKey* DefaultClientProofSource::LookupExact( + absl::string_view map_key) const { + const auto it = cert_and_keys_.find(map_key); + QUIC_DVLOG(1) << "LookupExact(" << map_key + << ") found:" << (it != cert_and_keys_.end()); + if (it != cert_and_keys_.end()) { + return it->second.get(); + } + return nullptr; +} + +} // namespace quic diff --git a/gquiche/quic/core/crypto/client_proof_source.h b/gquiche/quic/core/crypto/client_proof_source.h new file mode 100644 index 00000000..4e8d2ad8 --- /dev/null +++ b/gquiche/quic/core/crypto/client_proof_source.h @@ -0,0 +1,70 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_CLIENT_PROOF_SOURCE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_CLIENT_PROOF_SOURCE_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "gquiche/quic/core/crypto/certificate_view.h" +#include "gquiche/quic/core/crypto/proof_source.h" + +namespace quic { + +// ClientProofSource is the interface for a QUIC client to provide client certs +// and keys based on server hostname. It is only used by TLS handshakes. +class QUIC_EXPORT_PRIVATE ClientProofSource { + public: + using Chain = ProofSource::Chain; + + virtual ~ClientProofSource() {} + + struct QUIC_EXPORT_PRIVATE CertAndKey { + CertAndKey(QuicReferenceCountedPointer chain, + CertificatePrivateKey private_key) + : chain(std::move(chain)), private_key(std::move(private_key)) {} + + QuicReferenceCountedPointer chain; + CertificatePrivateKey private_key; + }; + + // Get the client certificate to be sent to the server with |server_hostname| + // and its corresponding private key. It returns nullptr if the cert and key + // can not be found. + // + // |server_hostname| is typically a full domain name(www.foo.com), but it + // could also be a wildcard domain(*.foo.com), or a "*" which will return the + // default cert. + virtual const CertAndKey* GetCertAndKey( + absl::string_view server_hostname) const = 0; +}; + +// DefaultClientProofSource is an implementation that simply keeps an in memory +// map of server hostnames to certs. +class QUIC_EXPORT_PRIVATE DefaultClientProofSource : public ClientProofSource { + public: + ~DefaultClientProofSource() override {} + + // Associate all hostnames in |server_hostnames| with {|chain|,|private_key|}. + // Elements of |server_hostnames| can be full domain names(www.foo.com), + // wildcard domains(*.foo.com), or "*" which means the given cert chain is the + // default one. + // If any element of |server_hostnames| is already associated with a cert + // chain, it will be updated to be associated with the new cert chain. + bool AddCertAndKey(std::vector server_hostnames, + QuicReferenceCountedPointer chain, + CertificatePrivateKey private_key); + + // ClientProofSource implementation + const CertAndKey* GetCertAndKey(absl::string_view hostname) const override; + + private: + const CertAndKey* LookupExact(absl::string_view map_key) const; + absl::flat_hash_map> cert_and_keys_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_CLIENT_PROOF_SOURCE_H_ diff --git a/gquiche/quic/core/crypto/client_proof_source_test.cc b/gquiche/quic/core/crypto/client_proof_source_test.cc new file mode 100644 index 00000000..578c41c9 --- /dev/null +++ b/gquiche/quic/core/crypto/client_proof_source_test.cc @@ -0,0 +1,212 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/crypto/client_proof_source.h" + +#include "gquiche/quic/platform/api/quic_expect_bug.h" +#include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/quic/test_tools/test_certificates.h" + +namespace quic { +namespace test { + +QuicReferenceCountedPointer TestCertChain() { + return QuicReferenceCountedPointer( + new ClientProofSource::Chain({std::string(kTestCertificate)})); +} + +CertificatePrivateKey TestPrivateKey() { + CBS private_key_cbs; + CBS_init(&private_key_cbs, + reinterpret_cast(kTestCertificatePrivateKey.data()), + kTestCertificatePrivateKey.size()); + + return CertificatePrivateKey( + bssl::UniquePtr(EVP_parse_private_key(&private_key_cbs))); +} + +const ClientProofSource::CertAndKey* TestCertAndKey() { + static const ClientProofSource::CertAndKey cert_and_key(TestCertChain(), + TestPrivateKey()); + return &cert_and_key; +} + +QuicReferenceCountedPointer NullCertChain() { + return QuicReferenceCountedPointer(); +} + +QuicReferenceCountedPointer EmptyCertChain() { + return QuicReferenceCountedPointer( + new ClientProofSource::Chain(std::vector())); +} + +QuicReferenceCountedPointer BadCertChain() { + return QuicReferenceCountedPointer( + new ClientProofSource::Chain({"This is the content of a bad cert."})); +} + +CertificatePrivateKey EmptyPrivateKey() { + return CertificatePrivateKey(bssl::UniquePtr(EVP_PKEY_new())); +} + +#define VERIFY_CERT_AND_KEY_MATCHES(lhs, rhs) \ + do { \ + SCOPED_TRACE(testing::Message()); \ + VerifyCertAndKeyMatches(lhs, rhs); \ + } while (0) + +void VerifyCertAndKeyMatches(const ClientProofSource::CertAndKey* lhs, + const ClientProofSource::CertAndKey* rhs) { + if (lhs == rhs) { + return; + } + + if (lhs == nullptr) { + ADD_FAILURE() << "lhs is nullptr, but rhs is not"; + return; + } + + if (rhs == nullptr) { + ADD_FAILURE() << "rhs is nullptr, but lhs is not"; + return; + } + + if (1 != EVP_PKEY_cmp(lhs->private_key.private_key(), + rhs->private_key.private_key())) { + ADD_FAILURE() << "Private keys mismatch"; + return; + } + + const ClientProofSource::Chain* lhs_chain = lhs->chain.get(); + const ClientProofSource::Chain* rhs_chain = rhs->chain.get(); + + if (lhs_chain == rhs_chain) { + return; + } + + if (lhs_chain == nullptr) { + ADD_FAILURE() << "lhs->chain is nullptr, but rhs->chain is not"; + return; + } + + if (rhs_chain == nullptr) { + ADD_FAILURE() << "rhs->chain is nullptr, but lhs->chain is not"; + return; + } + + if (lhs_chain->certs.size() != rhs_chain->certs.size()) { + ADD_FAILURE() << "Cert chain length differ. lhs:" << lhs_chain->certs.size() + << ", rhs:" << rhs_chain->certs.size(); + return; + } + + for (size_t i = 0; i < lhs_chain->certs.size(); ++i) { + if (lhs_chain->certs[i] != rhs_chain->certs[i]) { + ADD_FAILURE() << "The " << i << "-th certs differ."; + return; + } + } + + // All good. +} + +TEST(DefaultClientProofSource, FullDomain) { + DefaultClientProofSource proof_source; + ASSERT_TRUE(proof_source.AddCertAndKey({"www.google.com"}, TestCertChain(), + TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + EXPECT_EQ(proof_source.GetCertAndKey("*.google.com"), nullptr); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, WildcardDomain) { + DefaultClientProofSource proof_source; + ASSERT_TRUE(proof_source.AddCertAndKey({"*.google.com"}, TestCertChain(), + TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*.google.com"), + TestCertAndKey()); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, DefaultDomain) { + DefaultClientProofSource proof_source; + ASSERT_TRUE( + proof_source.AddCertAndKey({"*"}, TestCertChain(), TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*"), + TestCertAndKey()); +} + +TEST(DefaultClientProofSource, FullAndWildcard) { + DefaultClientProofSource proof_source; + ASSERT_TRUE(proof_source.AddCertAndKey({"www.google.com", "*.google.com"}, + TestCertChain(), TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("foo.google.com"), + TestCertAndKey()); + EXPECT_EQ(proof_source.GetCertAndKey("www.example.com"), nullptr); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, FullWildcardAndDefault) { + DefaultClientProofSource proof_source; + ASSERT_TRUE( + proof_source.AddCertAndKey({"www.google.com", "*.google.com", "*"}, + TestCertChain(), TestPrivateKey())); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("foo.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("www.example.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*.google.com"), + TestCertAndKey()); + VERIFY_CERT_AND_KEY_MATCHES(proof_source.GetCertAndKey("*"), + TestCertAndKey()); +} + +TEST(DefaultClientProofSource, EmptyCerts) { + DefaultClientProofSource proof_source; + bool ok; + EXPECT_QUIC_BUG( + ok = proof_source.AddCertAndKey({"*"}, NullCertChain(), TestPrivateKey()), + "Certificate chain is empty"); + ASSERT_FALSE(ok); + + EXPECT_QUIC_BUG(ok = proof_source.AddCertAndKey({"*"}, EmptyCertChain(), + TestPrivateKey()), + "Certificate chain is empty"); + ASSERT_FALSE(ok); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, BadCerts) { + DefaultClientProofSource proof_source; + bool ok; + EXPECT_QUIC_BUG( + ok = proof_source.AddCertAndKey({"*"}, BadCertChain(), TestPrivateKey()), + "Unabled to parse leaf certificate"); + ASSERT_FALSE(ok); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +TEST(DefaultClientProofSource, KeyMismatch) { + DefaultClientProofSource proof_source; + bool ok; + EXPECT_QUIC_BUG(ok = proof_source.AddCertAndKey( + {"www.google.com"}, TestCertChain(), EmptyPrivateKey()), + "Private key does not match the leaf certificate"); + ASSERT_FALSE(ok); + EXPECT_EQ(proof_source.GetCertAndKey("*"), nullptr); +} + +} // namespace test +} // namespace quic diff --git a/gquiche/quic/core/crypto/crypto_handshake_message.cc b/gquiche/quic/core/crypto/crypto_handshake_message.cc index 5dca5eca..6c0f70dd 100644 --- a/gquiche/quic/core/crypto/crypto_handshake_message.cc +++ b/gquiche/quic/core/crypto/crypto_handshake_message.cc @@ -16,7 +16,6 @@ #include "gquiche/quic/core/crypto/crypto_utils.h" #include "gquiche/quic/core/quic_socket_address_coder.h" #include "gquiche/quic/core/quic_utils.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/common/quiche_endian.h" namespace quic { @@ -171,7 +170,7 @@ bool CryptoHandshakeMessage::GetStringPiece(QuicTag tag, } bool CryptoHandshakeMessage::HasStringPiece(QuicTag tag) const { - return QuicContainsKey(tag_value_map_, tag); + return tag_value_map_.find(tag) != tag_value_map_.end(); } QuicErrorCode CryptoHandshakeMessage::GetNthValue24( diff --git a/gquiche/quic/core/crypto/crypto_message_printer_bin.cc b/gquiche/quic/core/crypto/crypto_message_printer_bin.cc index 3ee6ff33..5b071775 100644 --- a/gquiche/quic/core/crypto/crypto_message_printer_bin.cc +++ b/gquiche/quic/core/crypto/crypto_message_printer_bin.cc @@ -13,7 +13,6 @@ #include "absl/strings/escaping.h" #include "gquiche/quic/core/crypto/crypto_framer.h" #include "gquiche/quic/core/quic_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using std::cerr; using std::cout; diff --git a/gquiche/quic/core/crypto/crypto_protocol.h b/gquiche/quic/core/crypto/crypto_protocol.h index 24917e88..a47ece60 100644 --- a/gquiche/quic/core/crypto/crypto_protocol.h +++ b/gquiche/quic/core/crypto/crypto_protocol.h @@ -28,7 +28,8 @@ using ServerConfigID = std::string; // The following tags have been deprecated and should not be reused: // "1CON", "BBQ4", "NCON", "RCID", "SREJ", "TBKP", "TB10", "SCLS", "SMHL", -// "QNZR", "B2HI", "H2PR", "FIFO", "LIFO", "RRWS" +// "QNZR", "B2HI", "H2PR", "FIFO", "LIFO", "RRWS", "QNSP", "B2CL", "CHSP", +// "BPTE", "ACKD", "AKD2", "AKD4", "MAD1", "MAD4", "MAD5", "ACD0", "ACKQ" // clang-format off const QuicTag kCHLO = TAG('C', 'H', 'L', 'O'); // Client hello @@ -98,6 +99,11 @@ const QuicTag kBBR3 = TAG('B', 'B', 'R', '3'); // Fully drain the queue once const QuicTag kBBR4 = TAG('B', 'B', 'R', '4'); // 20 RTT ack aggregation const QuicTag kBBR5 = TAG('B', 'B', 'R', '5'); // 40 RTT ack aggregation const QuicTag kBBR9 = TAG('B', 'B', 'R', '9'); // DEPRECATED +const QuicTag kBBRA = TAG('B', 'B', 'R', 'A'); // Starts a new ack aggregation + // epoch if a full round has + // passed +const QuicTag kBBRB = TAG('B', 'B', 'R', 'B'); // Use send rate in BBR's + // MaxAckHeightTracker const QuicTag kBBRS = TAG('B', 'B', 'R', 'S'); // DEPRECATED const QuicTag kBBQ1 = TAG('B', 'B', 'Q', '1'); // BBR with lower 2.77 STARTUP // pacing and CWND gain. @@ -115,6 +121,8 @@ const QuicTag kBBQ7 = TAG('B', 'B', 'Q', '7'); // Reduce bw_lo by const QuicTag kBBQ8 = TAG('B', 'B', 'Q', '8'); // Reduce bw_lo by // bw_lo * bytes_lost/inflight const QuicTag kBBQ9 = TAG('B', 'B', 'Q', '9'); // Reduce bw_lo by +const QuicTag kBBQ0 = TAG('B', 'B', 'Q', '0'); // Increase bytes_acked in + // PROBE_UP when app limited. // bw_lo * bytes_lost/cwnd const QuicTag kRENO = TAG('R', 'E', 'N', 'O'); // Reno Congestion Control const QuicTag kTPCC = TAG('P', 'C', 'C', '\0'); // Performance-Oriented @@ -127,14 +135,12 @@ const QuicTag kIW50 = TAG('I', 'W', '5', '0'); // Force ICWND to 50 const QuicTag kB2ON = TAG('B', '2', 'O', 'N'); // Enable BBRv2 const QuicTag kB2NA = TAG('B', '2', 'N', 'A'); // For BBRv2, do not add ack // height to queueing threshold -const QuicTag kB2NE = TAG('B', '2', 'N', 'E'); // For BBRv2, do not exit - // STARTUP if there's enough - // bandwidth growth +const QuicTag kB2NE = TAG('B', '2', 'N', 'E'); // For BBRv2, always exit + // STARTUP on loss, even if + // bandwidth growth exceeds + // threshold. const QuicTag kB2RP = TAG('B', '2', 'R', 'P'); // For BBRv2, run PROBE_RTT on // the regular schedule -const QuicTag kB2CL = TAG('B', '2', 'C', 'L'); // For BBRv2, allow PROBE_BW - // cwnd to be below BDP + ack - // height. const QuicTag kB2LO = TAG('B', '2', 'L', 'O'); // Ignore inflight_lo in BBR2 const QuicTag kB2HR = TAG('B', '2', 'H', 'R'); // 15% inflight_hi headroom. const QuicTag kB2SL = TAG('B', '2', 'S', 'L'); // When exiting STARTUP due to @@ -152,6 +158,20 @@ const QuicTag kBSAO = TAG('B', 'S', 'A', 'O'); // Avoid Overestimation in // aggregation const QuicTag kB2DL = TAG('B', '2', 'D', 'L'); // Increase inflight_hi based // on delievered, not inflight. +const QuicTag kB201 = TAG('B', '2', '0', '1'); // In PROBE_UP, check if cwnd + // limited before aggregation + // epoch, instead of ack event. +const QuicTag kB202 = TAG('B', '2', '0', '2'); // Do not exit PROBE_UP if + // inflight dips below 1.25*BW. +const QuicTag kB203 = TAG('B', '2', '0', '3'); // Ignore inflight_hi until + // PROBE_UP is exited. +const QuicTag kB204 = TAG('B', '2', '0', '4'); // Reduce extra acked when + // MaxBW incrases. +const QuicTag kB205 = TAG('B', '2', '0', '5'); // Add extra acked to CWND in + // STARTUP. +const QuicTag kB206 = TAG('B', '2', '0', '6'); // Exit STARTUP after 2 losses. +const QuicTag kB207 = TAG('B', '2', '0', '7'); // Exit STARTUP on persistent + // queue const QuicTag kNTLP = TAG('N', 'T', 'L', 'P'); // No tail loss probe const QuicTag k1TLP = TAG('1', 'T', 'L', 'P'); // 1 tail loss probe const QuicTag k1RTO = TAG('1', 'R', 'T', 'O'); // Send 1 packet upon RTO @@ -164,24 +184,13 @@ const QuicTag kMIN4 = TAG('M', 'I', 'N', '4'); // Min CWND of 4 packets, const QuicTag kTLPR = TAG('T', 'L', 'P', 'R'); // Tail loss probe delay of // 0.5RTT. const QuicTag kMAD0 = TAG('M', 'A', 'D', '0'); // Ignore ack delay -const QuicTag kMAD1 = TAG('M', 'A', 'D', '1'); // 25ms initial max ack delay const QuicTag kMAD2 = TAG('M', 'A', 'D', '2'); // No min TLP const QuicTag kMAD3 = TAG('M', 'A', 'D', '3'); // No min RTO -const QuicTag kMAD4 = TAG('M', 'A', 'D', '4'); // IETF style TLP -const QuicTag kMAD5 = TAG('M', 'A', 'D', '5'); // IETF style TLP with 2x mult const QuicTag k1ACK = TAG('1', 'A', 'C', 'K'); // 1 fast ack for reordering -const QuicTag kACD0 = TAG('A', 'D', 'D', '0'); // Disable ack decimation -const QuicTag kACKD = TAG('A', 'C', 'K', 'D'); // Ack decimation style acking. -const QuicTag kAKD2 = TAG('A', 'K', 'D', '2'); // Ack decimation tolerating - // out of order packets. const QuicTag kAKD3 = TAG('A', 'K', 'D', '3'); // Ack decimation style acking // with 1/8 RTT acks. -const QuicTag kAKD4 = TAG('A', 'K', 'D', '4'); // Ack decimation with 1/8 RTT - // tolerating out of order. const QuicTag kAKDU = TAG('A', 'K', 'D', 'U'); // Unlimited number of packets // received before acking -const QuicTag kACKQ = TAG('A', 'C', 'K', 'Q'); // Send an immediate ack after - // 1 RTT of not receiving. const QuicTag kAFFE = TAG('A', 'F', 'F', 'E'); // Enable client receiving // AckFrequencyFrame. const QuicTag kAFF1 = TAG('A', 'F', 'F', '1'); // Use SRTT in building @@ -202,7 +211,7 @@ const QuicTag kCONH = TAG('C', 'O', 'N', 'H'); // Conservative Handshake // Retransmissions. const QuicTag kLFAK = TAG('L', 'F', 'A', 'K'); // Don't invoke FACK on the // first ack. -const QuicTag kSTMP = TAG('S', 'T', 'M', 'P'); // Send and process timestamps +const QuicTag kSTMP = TAG('S', 'T', 'M', 'P'); // DEPRECATED const QuicTag kEACK = TAG('E', 'A', 'C', 'K'); // Bundle ack-eliciting frame // with an ACK after PTO/RTO @@ -267,6 +276,8 @@ const QuicTag kAPTO = TAG('A', 'P', 'T', 'O'); // Use 1.5 * initial RTT before const QuicTag kELDT = TAG('E', 'L', 'D', 'T'); // Enable Loss Detection Tuning +// TODO(haoyuewang) Remove RVCM option once +// --quic_remove_connection_migration_connection_option is deprecated. const QuicTag kRVCM = TAG('R', 'V', 'C', 'M'); // Validate the new address // upon client address change. @@ -342,6 +353,9 @@ const QuicTag kMTUL = TAG('M', 'T', 'U', 'L'); // Low-target MTU discovery. const QuicTag kNSLC = TAG('N', 'S', 'L', 'C'); // Always send connection close // for idle timeout. +const QuicTag kNCHP = TAG('N', 'C', 'H', 'P'); // No chaos protection. +const QuicTag kNBPE = TAG('N', 'B', 'P', 'E'); // No BoringSSL Permutes + // TLS Extensions. // Proof types (i.e. certificate types) // NOTE: although it would be silly to do so, specifying both kX509 and kX59R @@ -367,6 +381,9 @@ const QuicTag kMIUS = TAG('M', 'I', 'U', 'S'); // Max incoming unidi streams const QuicTag kADE = TAG('A', 'D', 'E', 0); // Ack Delay Exponent (IETF // QUIC ACK Frame Only). const QuicTag kIRTT = TAG('I', 'R', 'T', 'T'); // Estimated initial RTT in us. +const QuicTag kTRTT = TAG('T', 'R', 'T', 'T'); // If server receives an rtt + // from an address token, set + // it as the initial rtt. const QuicTag kSNI = TAG('S', 'N', 'I', '\0'); // Server name // indication const QuicTag kPUBS = TAG('P', 'U', 'B', 'S'); // Public key values @@ -387,6 +404,9 @@ const QuicTag kXLCT = TAG('X', 'L', 'C', 'T'); // Expected leaf certificate. const QuicTag kQLVE = TAG('Q', 'L', 'V', 'E'); // Legacy Version // Encapsulation. +const QuicTag kPDP1 = TAG('P', 'D', 'P', '1'); // Path degrading triggered + // at 1PTO. + const QuicTag kPDP2 = TAG('P', 'D', 'P', '2'); // Path degrading triggered // at 2PTO. @@ -401,9 +421,6 @@ const QuicTag kPDP5 = TAG('P', 'D', 'P', '5'); // Path degrading triggered const QuicTag kQNZ2 = TAG('Q', 'N', 'Z', '2'); // Turn off QUIC crypto 0-RTT. -const QuicTag kQNSP = TAG('Q', 'N', 'S', 'P'); // Turn off server push in - // gQUIC. - const QuicTag kMAD = TAG('M', 'A', 'D', 0); // Max Ack Delay (IETF QUIC) const QuicTag kIGNP = TAG('I', 'G', 'N', 'P'); // Do not use PING only packet @@ -414,6 +431,10 @@ const QuicTag kSRWP = TAG('S', 'R', 'W', 'P'); // Enable retransmittable on // wire PING (ROWP) on the // server side. +// Client Hints triggers. +const QuicTag kGWCH = TAG('G', 'W', 'C', 'H'); +const QuicTag kYTCH = TAG('Y', 'T', 'C', 'H'); + // Rejection tags const QuicTag kRREJ = TAG('R', 'R', 'E', 'J'); // Reasons for server sending diff --git a/gquiche/quic/core/crypto/crypto_server_test.cc b/gquiche/quic/core/crypto/crypto_server_test.cc index c35479a2..91bc7bbc 100644 --- a/gquiche/quic/core/crypto/crypto_server_test.cc +++ b/gquiche/quic/core/crypto/crypto_server_test.cc @@ -12,6 +12,7 @@ #include "absl/base/macros.h" #include "absl/strings/escaping.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "openssl/sha.h" @@ -33,7 +34,6 @@ #include "gquiche/quic/test_tools/mock_random.h" #include "gquiche/quic/test_tools/quic_crypto_server_config_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/quiche_endian.h" namespace quic { @@ -60,21 +60,19 @@ const char kOldConfigId[] = "old-config-id"; struct TestParams { friend std::ostream& operator<<(std::ostream& os, const TestParams& p) { os << " versions: " - << ParsedQuicVersionVectorToString(p.supported_versions) - << " } allow_sni_without_dots: " << p.allow_sni_without_dots; + << ParsedQuicVersionVectorToString(p.supported_versions) << " }"; return os; } // Versions supported by client and server. ParsedQuicVersionVector supported_versions; - bool allow_sni_without_dots; }; // Used by ::testing::PrintToStringParamName(). std::string PrintToString(const TestParams& p) { std::string rv = ParsedQuicVersionVectorToString(p.supported_versions); std::replace(rv.begin(), rv.end(), ',', '_'); - return absl::StrCat(rv, "_allow_sni_without_dots_", p.allow_sni_without_dots); + return rv; } // Constructs various test permutations. @@ -84,9 +82,7 @@ std::vector GetTestParams() { // Start with all versions, remove highest on each iteration. ParsedQuicVersionVector supported_versions = AllSupportedVersions(); while (!supported_versions.empty()) { - for (bool allow_sni_without_dots : {false, true}) { - params.push_back({supported_versions, allow_sni_without_dots}); - } + params.push_back({supported_versions}); supported_versions.erase(supported_versions.begin()); } @@ -110,8 +106,6 @@ class CryptoServerTest : public QuicTestWithParam { signed_config_(new QuicSignedServerConfig), chlo_packet_size_(kDefaultMaxPacketSize) { supported_versions_ = GetParam().supported_versions; - SetQuicReloadableFlag(quic_and_tls_allow_sni_without_dots, - GetParam().allow_sni_without_dots); config_.set_enable_serving_sct(true); client_version_ = supported_versions_.front(); @@ -280,7 +274,7 @@ class CryptoServerTest : public QuicTestWithParam { } else { ASSERT_NE(error, QUIC_NO_ERROR) << "Message didn't fail: " << result_->client_hello.DebugString(); - EXPECT_TRUE(error_details.find(error_substr_) != std::string::npos) + EXPECT_TRUE(absl::StrContains(error_details, error_substr_)) << error_substr_ << " not in " << error_details; } if (message != nullptr) { @@ -388,9 +382,6 @@ TEST_P(CryptoServerTest, BadSNI) { "127.0.0.1", "ffee::1", }; - if (!GetParam().allow_sni_without_dots) { - badSNIs.push_back("foo"); - } // clang-format on for (const std::string& bad_sni : badSNIs) { @@ -403,12 +394,11 @@ TEST_P(CryptoServerTest, BadSNI) { CheckRejectReasons(kRejectReasons, ABSL_ARRAYSIZE(kRejectReasons)); } - if (GetParam().allow_sni_without_dots) { - CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( - {{"PDMD", "X509"}, {"SNI", "foo"}, {"VER\0", client_version_string_}}, - kClientHelloMinimumSize); - ShouldSucceed(msg); - } + // Check that SNIs without dots are allowed + CryptoHandshakeMessage msg = crypto_test_utils::CreateCHLO( + {{"PDMD", "X509"}, {"SNI", "foo"}, {"VER\0", client_version_string_}}, + kClientHelloMinimumSize); + ShouldSucceed(msg); } TEST_P(CryptoServerTest, DefaultCert) { diff --git a/gquiche/quic/core/crypto/crypto_utils.cc b/gquiche/quic/core/crypto/crypto_utils.cc index 9a068e08..9a4f7cb4 100644 --- a/gquiche/quic/core/crypto/crypto_utils.cc +++ b/gquiche/quic/core/crypto/crypto_utils.cc @@ -146,10 +146,6 @@ const uint8_t kRFCv1InitialSalt[] = {0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, const uint8_t kQ050Salt[] = {0x50, 0x45, 0x74, 0xef, 0xd0, 0x66, 0xfe, 0x2f, 0x9d, 0x94, 0x5c, 0xfc, 0xdb, 0xd3, 0xa7, 0xf0, 0xd3, 0xb5, 0x6b, 0x45}; -// Salt to use for initial obfuscators in version T051. -const uint8_t kT051Salt[] = {0x7a, 0x4e, 0xde, 0xf4, 0xe7, 0xcc, 0xee, - 0x5f, 0xa4, 0x50, 0x6c, 0x19, 0x12, 0x4f, - 0xc8, 0xcc, 0xda, 0x6e, 0x03, 0x3d}; // Salt to use for initial obfuscators in // ParsedQuicVersion::ReservedForNegotiation(). const uint8_t kReservedForNegotiationSalt[] = { @@ -158,7 +154,7 @@ const uint8_t kReservedForNegotiationSalt[] = { const uint8_t* InitialSaltForVersion(const ParsedQuicVersion& version, size_t* out_len) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync with initial encryption salts"); if (version == ParsedQuicVersion::RFCv1()) { *out_len = ABSL_ARRAYSIZE(kRFCv1InitialSalt); @@ -166,9 +162,6 @@ const uint8_t* InitialSaltForVersion(const ParsedQuicVersion& version, } else if (version == ParsedQuicVersion::Draft29()) { *out_len = ABSL_ARRAYSIZE(kDraft29InitialSalt); return kDraft29InitialSalt; - } else if (version == ParsedQuicVersion::T051()) { - *out_len = ABSL_ARRAYSIZE(kT051Salt); - return kT051Salt; } else if (version == ParsedQuicVersion::Q050()) { *out_len = ABSL_ARRAYSIZE(kQ050Salt); return kQ050Salt; @@ -186,6 +179,8 @@ const char kPreSharedKeyLabel[] = "QUIC PSK"; // Retry Integrity Protection Keys and Nonces. // https://tools.ietf.org/html/draft-ietf-quic-tls-29#section-5.8 +// When introducing a new Google version, generate a new key by running +// `openssl rand -hex 16`. const uint8_t kDraft29RetryIntegrityKey[] = {0xcc, 0xce, 0x18, 0x7e, 0xd0, 0x9a, 0x09, 0xd0, 0x57, 0x28, 0x15, 0x5a, 0x6c, 0xb9, 0x6b, 0xe1}; @@ -196,20 +191,12 @@ const uint8_t kRFCv1RetryIntegrityKey[] = {0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0xe3, 0x68, 0xc8, 0x4e}; const uint8_t kRFCv1RetryIntegrityNonce[] = { 0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb}; - -// Keys used by Google versions of QUIC. When introducing a new version, -// generate a new key by running `openssl rand -hex 16`. -const uint8_t kT051RetryIntegrityKey[] = {0x2e, 0xb9, 0x61, 0xa6, 0x79, 0x56, - 0xf8, 0x79, 0x53, 0x14, 0xda, 0xfb, - 0x2e, 0xbc, 0x83, 0xd7}; // Retry integrity key used by ParsedQuicVersion::ReservedForNegotiation(). const uint8_t kReservedForNegotiationRetryIntegrityKey[] = { 0xf2, 0xcd, 0x8f, 0xe0, 0x36, 0xd0, 0x25, 0x35, 0x03, 0xe6, 0x7c, 0x7b, 0xd2, 0x44, 0xca, 0xd9}; -// Nonces used by Google versions of QUIC. When introducing a new version, -// generate a new nonce by running `openssl rand -hex 12`. -const uint8_t kT051RetryIntegrityNonce[] = {0xb5, 0x0e, 0x4e, 0x53, 0x4c, 0xfc, - 0x0b, 0xbb, 0x85, 0xf2, 0xf9, 0xca}; +// When introducing a new Google version, generate a new nonce by running +// `openssl rand -hex 12`. // Retry integrity nonce used by ParsedQuicVersion::ReservedForNegotiation(). const uint8_t kReservedForNegotiationRetryIntegrityNonce[] = { 0x35, 0x9f, 0x16, 0xd1, 0xed, 0x80, 0x90, 0x8e, 0xec, 0x85, 0xc4, 0xd6}; @@ -217,7 +204,7 @@ const uint8_t kReservedForNegotiationRetryIntegrityNonce[] = { bool RetryIntegrityKeysForVersion(const ParsedQuicVersion& version, absl::string_view* key, absl::string_view* nonce) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync with retry integrity keys"); if (!version.UsesTls()) { QUIC_BUG(quic_bug_10699_2) @@ -240,14 +227,6 @@ bool RetryIntegrityKeysForVersion(const ParsedQuicVersion& version, reinterpret_cast(kDraft29RetryIntegrityNonce), ABSL_ARRAYSIZE(kDraft29RetryIntegrityNonce)); return true; - } else if (version == ParsedQuicVersion::T051()) { - *key = - absl::string_view(reinterpret_cast(kT051RetryIntegrityKey), - ABSL_ARRAYSIZE(kT051RetryIntegrityKey)); - *nonce = absl::string_view( - reinterpret_cast(kT051RetryIntegrityNonce), - ABSL_ARRAYSIZE(kT051RetryIntegrityNonce)); - return true; } else if (version == ParsedQuicVersion::ReservedForNegotiation()) { *key = absl::string_view( reinterpret_cast(kReservedForNegotiationRetryIntegrityKey), @@ -541,35 +520,6 @@ bool CryptoUtils::DeriveKeys(const ParsedQuicVersion& version, return true; } -// static -bool CryptoUtils::ExportKeyingMaterial(absl::string_view subkey_secret, - absl::string_view label, - absl::string_view context, - size_t result_len, - std::string* result) { - for (size_t i = 0; i < label.length(); i++) { - if (label[i] == '\0') { - QUIC_LOG(ERROR) << "ExportKeyingMaterial label may not contain NULs"; - return false; - } - } - // Create HKDF info input: null-terminated label + length-prefixed context - if (context.length() >= std::numeric_limits::max()) { - QUIC_LOG(ERROR) << "Context value longer than 2^32"; - return false; - } - uint32_t context_length = static_cast(context.length()); - std::string info = std::string(label); - info.push_back('\0'); - info.append(reinterpret_cast(&context_length), sizeof(context_length)); - info.append(context.data(), context.length()); - - QuicHKDF hkdf(subkey_secret, absl::string_view() /* no salt */, info, - result_len, 0 /* no fixed IV */, 0 /* no subkey secret */); - *result = std::string(hkdf.client_write_key()); - return true; -} - // static uint64_t CryptoUtils::ComputeLeafCertHash(absl::string_view cert) { return QuicUtils::FNV1a_64_Hash(cert); @@ -670,6 +620,62 @@ QuicErrorCode CryptoUtils::ValidateClientHelloVersion( return QUIC_NO_ERROR; } +// static +bool CryptoUtils::ValidateChosenVersion( + const QuicVersionLabel& version_information_chosen_version, + const ParsedQuicVersion& session_version, std::string* error_details) { + if (version_information_chosen_version != + CreateQuicVersionLabel(session_version)) { + *error_details = absl::StrCat( + "Detected version mismatch: version_information contained ", + QuicVersionLabelToString(version_information_chosen_version), + " instead of ", ParsedQuicVersionToString(session_version)); + return false; + } + return true; +} + +// static +bool CryptoUtils::ValidateServerVersions( + const QuicVersionLabelVector& version_information_other_versions, + const ParsedQuicVersion& session_version, + const ParsedQuicVersionVector& client_original_supported_versions, + std::string* error_details) { + if (client_original_supported_versions.empty()) { + // We did not receive a version negotiation packet. + return true; + } + // Parse the server's other versions. + ParsedQuicVersionVector parsed_other_versions = + ParseQuicVersionLabelVector(version_information_other_versions); + // Find the first version that we originally supported that is listed in the + // server's other versions. + ParsedQuicVersion expected_version = ParsedQuicVersion::Unsupported(); + for (const ParsedQuicVersion& client_version : + client_original_supported_versions) { + if (std::find(parsed_other_versions.begin(), parsed_other_versions.end(), + client_version) != parsed_other_versions.end()) { + expected_version = client_version; + break; + } + } + if (expected_version != session_version) { + *error_details = absl::StrCat( + "Downgrade attack detected: used ", + ParsedQuicVersionToString(session_version), " but ServerVersions(", + version_information_other_versions.size(), ")[", + QuicVersionLabelVectorToString(version_information_other_versions, ",", + 30), + "] ClientOriginalVersions(", client_original_supported_versions.size(), + ")[", + ParsedQuicVersionVectorToString(client_original_supported_versions, ",", + 30), + "]"); + return false; + } + return true; +} + #define RETURN_STRING_LITERAL(x) \ case x: \ return #x @@ -713,35 +719,15 @@ const char* CryptoUtils::HandshakeFailureReasonToString( return "INVALID_HANDSHAKE_FAILURE_REASON"; } +#undef RETURN_STRING_LITERAL // undef for jumbo builds + // static std::string CryptoUtils::EarlyDataReasonToString( ssl_early_data_reason_t reason) { -#if BORINGSSL_API_VERSION >= 12 const char* reason_string = SSL_early_data_reason_string(reason); if (reason_string != nullptr) { return std::string("ssl_early_data_") + reason_string; } -#else - // TODO(davidben): Remove this logic once - // https://boringssl-review.googlesource.com/c/boringssl/+/43724 has landed in - // downstream repositories. - switch (reason) { - RETURN_STRING_LITERAL(ssl_early_data_unknown); - RETURN_STRING_LITERAL(ssl_early_data_disabled); - RETURN_STRING_LITERAL(ssl_early_data_accepted); - RETURN_STRING_LITERAL(ssl_early_data_protocol_version); - RETURN_STRING_LITERAL(ssl_early_data_peer_declined); - RETURN_STRING_LITERAL(ssl_early_data_no_session_offered); - RETURN_STRING_LITERAL(ssl_early_data_session_not_resumed); - RETURN_STRING_LITERAL(ssl_early_data_unsupported_for_session); - RETURN_STRING_LITERAL(ssl_early_data_hello_retry_request); - RETURN_STRING_LITERAL(ssl_early_data_alpn_mismatch); - RETURN_STRING_LITERAL(ssl_early_data_channel_id); - RETURN_STRING_LITERAL(ssl_early_data_token_binding); - RETURN_STRING_LITERAL(ssl_early_data_ticket_age_skew); - RETURN_STRING_LITERAL(ssl_early_data_quic_parameter_mismatch); - } -#endif QUIC_BUG_IF(quic_bug_12871_3, reason < 0 || reason > ssl_early_data_reason_max_value) << "Unknown ssl_early_data_reason_t " << reason; @@ -761,5 +747,20 @@ std::string CryptoUtils::HashHandshakeMessage( return output; } -#undef RETURN_STRING_LITERAL // undef for jumbo builds +// static +bool CryptoUtils::GetSSLCapabilities(const SSL* ssl, + bssl::UniquePtr* capabilities, + size_t* capabilities_len) { + uint8_t* buffer; + CBB cbb; + + if (!CBB_init(&cbb, 128) || !SSL_serialize_capabilities(ssl, &cbb) || + !CBB_finish(&cbb, &buffer, capabilities_len)) { + return false; + } + + *capabilities = bssl::UniquePtr(buffer); + return true; +} + } // namespace quic diff --git a/gquiche/quic/core/crypto/crypto_utils.h b/gquiche/quic/core/crypto/crypto_utils.h index c567b03c..a58d31a5 100644 --- a/gquiche/quic/core/crypto/crypto_utils.h +++ b/gquiche/quic/core/crypto/crypto_utils.h @@ -166,16 +166,6 @@ class QUIC_EXPORT_PRIVATE CryptoUtils { CrypterPair* crypters, std::string* subkey_secret); - // Performs key extraction to derive a new secret of |result_len| bytes - // dependent on |subkey_secret|, |label|, and |context|. Returns false if the - // parameters are invalid (e.g. |label| contains null bytes); returns true on - // success. - static bool ExportKeyingMaterial(absl::string_view subkey_secret, - absl::string_view label, - absl::string_view context, - size_t result_len, - std::string* result); - // Computes the FNV-1a hash of the provided DER-encoded cert for use in the // XLCT tag. static uint64_t ComputeLeafCertHash(absl::string_view cert); @@ -227,6 +217,27 @@ class QUIC_EXPORT_PRIVATE CryptoUtils { const ParsedQuicVersionVector& supported_versions, std::string* error_details); + // Validates that the chosen version from the version_information matches the + // version from the session. Returns true if they match, otherwise returns + // false and fills in |error_details|. + static bool ValidateChosenVersion( + const QuicVersionLabel& version_information_chosen_version, + const ParsedQuicVersion& session_version, std::string* error_details); + + // Validates that there was no downgrade attack involving a version + // negotiation packet. This verifies that if the client was initially + // configured with |client_original_supported_versions| and it had received a + // version negotiation packet with |version_information_other_versions|, then + // it would have selected |session_version|. Returns true if they match (or if + // |client_original_supported_versions| is empty indicating no version + // negotiation packet was received), otherwise returns + // false and fills in |error_details|. + static bool ValidateServerVersions( + const QuicVersionLabelVector& version_information_other_versions, + const ParsedQuicVersion& session_version, + const ParsedQuicVersionVector& client_original_supported_versions, + std::string* error_details); + // Returns the name of the HandshakeFailureReason as a char* static const char* HandshakeFailureReasonToString( HandshakeFailureReason reason); @@ -237,6 +248,11 @@ class QUIC_EXPORT_PRIVATE CryptoUtils { // Returns a hash of the serialized |message|. static std::string HashHandshakeMessage(const CryptoHandshakeMessage& message, Perspective perspective); + + // Wraps SSL_serialize_capabilities. Return nullptr if failed. + static bool GetSSLCapabilities(const SSL* ssl, + bssl::UniquePtr* capabilities, + size_t* capabilities_len); }; } // namespace quic diff --git a/gquiche/quic/core/crypto/crypto_utils_test.cc b/gquiche/quic/core/crypto/crypto_utils_test.cc index 80a3f7eb..58a9063b 100644 --- a/gquiche/quic/core/crypto/crypto_utils_test.cc +++ b/gquiche/quic/core/crypto/crypto_utils_test.cc @@ -11,7 +11,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace quic { @@ -20,63 +19,6 @@ namespace { class CryptoUtilsTest : public QuicTest {}; -TEST_F(CryptoUtilsTest, TestExportKeyingMaterial) { - const struct TestVector { - // Input (strings of hexadecimal digits): - const char* subkey_secret; - const char* label; - const char* context; - size_t result_len; - - // Expected output (string of hexadecimal digits): - const char* expected; // Null if it should fail. - } test_vector[] = { - // Try a typical input - {"4823c1189ecc40fce888fbb4cf9ae6254f19ba12e6d9af54788f195a6f509ca3", - "e934f78d7a71dd85420fceeb8cea0317", - "b8d766b5d3c8aba0009c7ed3de553eba53b4de1030ea91383dcdf724cd8b7217", 32, - "a9979da0d5f1c1387d7cbe68f5c4163ddb445a03c4ad6ee72cb49d56726d679e"}, - // Don't let the label contain nulls - {"14fe51e082ffee7d1b4d8d4ab41f8c55", "3132333435363700", - "58585858585858585858585858585858", 16, nullptr}, - // Make sure nulls in the context are fine - {"d862c2e36b0a42f7827c67ebc8d44df7", "7a5b95e4e8378123", - "4142434445464700", 16, "12d418c6d0738a2e4d85b2d0170f76e1"}, - // ... and give a different result than without - {"d862c2e36b0a42f7827c67ebc8d44df7", "7a5b95e4e8378123", "41424344454647", - 16, "abfa1c479a6e3ffb98a11dee7d196408"}, - // Try weird lengths - {"d0ec8a34f6cc9a8c96", "49711798cc6251", - "933d4a2f30d22f089cfba842791116adc121e0", 23, - "c9a46ed0757bd1812f1f21b4d41e62125fec8364a21db7"}, - }; - - for (size_t i = 0; i < ABSL_ARRAYSIZE(test_vector); i++) { - // Decode the test vector. - std::string subkey_secret = - absl::HexStringToBytes(test_vector[i].subkey_secret); - std::string label = absl::HexStringToBytes(test_vector[i].label); - std::string context = absl::HexStringToBytes(test_vector[i].context); - size_t result_len = test_vector[i].result_len; - bool expect_ok = test_vector[i].expected != nullptr; - std::string expected; - if (expect_ok) { - expected = absl::HexStringToBytes(test_vector[i].expected); - } - - std::string result; - bool ok = CryptoUtils::ExportKeyingMaterial(subkey_secret, label, context, - result_len, &result); - EXPECT_EQ(expect_ok, ok); - if (expect_ok) { - EXPECT_EQ(result_len, result.length()); - quiche::test::CompareCharArraysWithHexError( - "HKDF output", result.data(), result.length(), expected.data(), - expected.length()); - } - } -} - TEST_F(CryptoUtilsTest, HandshakeFailureReasonToString) { EXPECT_STREQ("HANDSHAKE_OK", CryptoUtils::HandshakeFailureReasonToString(HANDSHAKE_OK)); @@ -166,6 +108,63 @@ TEST_F(CryptoUtilsTest, AuthTagLengths) { } } +TEST_F(CryptoUtilsTest, ValidateChosenVersion) { + for (const ParsedQuicVersion& v1 : AllSupportedVersions()) { + for (const ParsedQuicVersion& v2 : AllSupportedVersions()) { + std::string error_details; + bool success = CryptoUtils::ValidateChosenVersion( + CreateQuicVersionLabel(v1), v2, &error_details); + EXPECT_EQ(success, v1 == v2); + EXPECT_EQ(success, error_details.empty()); + } + } +} + +TEST_F(CryptoUtilsTest, ValidateServerVersionsNoVersionNegotiation) { + QuicVersionLabelVector version_information_other_versions; + ParsedQuicVersionVector client_original_supported_versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + std::string error_details; + EXPECT_TRUE(CryptoUtils::ValidateServerVersions( + version_information_other_versions, version, + client_original_supported_versions, &error_details)); + EXPECT_TRUE(error_details.empty()); + } +} + +TEST_F(CryptoUtilsTest, ValidateServerVersionsWithVersionNegotiation) { + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + QuicVersionLabelVector version_information_other_versions{ + CreateQuicVersionLabel(version)}; + ParsedQuicVersionVector client_original_supported_versions{ + ParsedQuicVersion::ReservedForNegotiation(), version}; + std::string error_details; + EXPECT_TRUE(CryptoUtils::ValidateServerVersions( + version_information_other_versions, version, + client_original_supported_versions, &error_details)); + EXPECT_TRUE(error_details.empty()); + } +} + +TEST_F(CryptoUtilsTest, ValidateServerVersionsWithDowngrade) { + if (AllSupportedVersions().size() <= 1) { + // We are not vulnerable to downgrade if we only support one version. + return; + } + ParsedQuicVersion client_version = AllSupportedVersions().front(); + ParsedQuicVersion server_version = AllSupportedVersions().back(); + ASSERT_NE(client_version, server_version); + QuicVersionLabelVector version_information_other_versions{ + CreateQuicVersionLabel(client_version)}; + ParsedQuicVersionVector client_original_supported_versions{ + ParsedQuicVersion::ReservedForNegotiation(), server_version}; + std::string error_details; + EXPECT_FALSE(CryptoUtils::ValidateServerVersions( + version_information_other_versions, server_version, + client_original_supported_versions, &error_details)); + EXPECT_FALSE(error_details.empty()); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/crypto/key_exchange.cc b/gquiche/quic/core/crypto/key_exchange.cc index 0aedd727..887f4eb9 100644 --- a/gquiche/quic/core/crypto/key_exchange.cc +++ b/gquiche/quic/core/crypto/key_exchange.cc @@ -31,10 +31,8 @@ std::unique_ptr CreateLocalSynchronousKeyExchange( switch (type) { case kC255: return Curve25519KeyExchange::New(rand); - break; case kP256: return P256KeyExchange::New(); - break; default: QUIC_BUG(quic_bug_10712_2) << "Unknown key exchange method: " << QuicTagToString(type); diff --git a/gquiche/quic/core/crypto/proof_source.cc b/gquiche/quic/core/crypto/proof_source.cc index 4ee2bda7..5250ff9d 100644 --- a/gquiche/quic/core/crypto/proof_source.cc +++ b/gquiche/quic/core/crypto/proof_source.cc @@ -2,9 +2,11 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "gquiche/quic/core/crypto/proof_source.h" + #include -#include "gquiche/quic/core/crypto/proof_source.h" +#include "gquiche/quic/platform/api/quic_bug_tracker.h" namespace quic { @@ -30,4 +32,28 @@ CryptoBuffers ProofSource::Chain::ToCryptoBuffers() const { return crypto_buffers; } +bool ValidateCertAndKey( + const QuicReferenceCountedPointer& chain, + const CertificatePrivateKey& key) { + if (chain.get() == nullptr || chain->certs.empty()) { + QUIC_BUG(quic_proof_source_empty_chain) << "Certificate chain is empty"; + return false; + } + + std::unique_ptr leaf = + CertificateView::ParseSingleCertificate(chain->certs[0]); + if (leaf == nullptr) { + QUIC_BUG(quic_proof_source_unparsable_leaf_cert) + << "Unabled to parse leaf certificate"; + return false; + } + + if (!key.MatchesPublicKey(*leaf)) { + QUIC_BUG(quic_proof_source_key_mismatch) + << "Private key does not match the leaf certificate"; + return false; + } + return true; +} + } // namespace quic diff --git a/gquiche/quic/core/crypto/proof_source.h b/gquiche/quic/core/crypto/proof_source.h index dfc0fee6..83eaca32 100644 --- a/gquiche/quic/core/crypto/proof_source.h +++ b/gquiche/quic/core/crypto/proof_source.h @@ -11,6 +11,7 @@ #include "absl/strings/string_view.h" #include "openssl/ssl.h" +#include "gquiche/quic/core/crypto/certificate_view.h" #include "gquiche/quic/core/crypto/quic_crypto_proof.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_export.h" @@ -145,10 +146,13 @@ class QUIC_EXPORT_PRIVATE ProofSource { std::unique_ptr callback) = 0; // Returns the certificate chain for |hostname| in leaf-first order. + // + // Sets *cert_matched_sni to true if the certificate matched the given + // hostname, false if a default cert not matching the hostname was used. virtual QuicReferenceCountedPointer GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) = 0; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) = 0; // Computes a signature using the private key of the certificate for // |hostname|. The value in |in| is signed using the algorithm specified by @@ -166,6 +170,15 @@ class QUIC_EXPORT_PRIVATE ProofSource { absl::string_view in, std::unique_ptr callback) = 0; + // Return the list of TLS signature algorithms that is acceptable by the + // ComputeTlsSignature method. If the entire BoringSSL's default list of + // supported signature algorithms are acceptable, return an empty list. + // + // If returns a non-empty list, ComputeTlsSignature will only be called with a + // algorithm in the list. + virtual absl::InlinedVector SupportedTlsSignatureAlgorithms() + const = 0; + class QUIC_EXPORT_PRIVATE DecryptCallback { public: DecryptCallback() = default; @@ -197,7 +210,12 @@ class QUIC_EXPORT_PRIVATE ProofSource { // returns the encrypted ticket. The resulting value must not be larger than // MaxOverhead bytes larger than |in|. If encryption fails, this method // returns an empty vector. - virtual std::vector Encrypt(absl::string_view in) = 0; + // + // If |encryption_key| is nonempty, this method should use it for minting + // TLS resumption tickets. If it is empty, this method may use an + // internally cached encryption key, if available. + virtual std::vector Encrypt(absl::string_view in, + absl::string_view encryption_key) = 0; // Decrypt takes an encrypted ticket |in|, decrypts it, and calls // |callback->Run| with the decrypted ticket, which must not be larger than @@ -229,12 +247,21 @@ class QUIC_EXPORT_PRIVATE ProofSourceHandleCallback { // whether it is completed before ProofSourceHandle::SelectCertificate // returned. // |chain| the certificate chain in leaf-first order. + // |handshake_hints| (optional) handshake hints that can be used by + // SSL_set_handshake_hints. + // |ticket_encryption_key| (optional) encryption key to be used for minting + // TLS resumption tickets. + // |cert_matched_sni| is true if the certificate matched the SNI hostname, + // false if a non-matching default cert was used. + // |delayed_ssl_config| contains SSL configs to be applied on the SSL object. // // When called asynchronously(is_sync=false), this method will be responsible // to continue the handshake from where it left off. - virtual void OnSelectCertificateDone(bool ok, - bool is_sync, - const ProofSource::Chain* chain) = 0; + virtual void OnSelectCertificateDone( + bool ok, bool is_sync, const ProofSource::Chain* chain, + absl::string_view handshake_hints, + absl::string_view ticket_encryption_key, bool cert_matched_sni, + QuicDelayedSSLConfig delayed_ssl_config) = 0; // Called when a ProofSourceHandle::ComputeSignature operation completes. virtual void OnComputeSignatureDone( @@ -242,6 +269,10 @@ class QUIC_EXPORT_PRIVATE ProofSourceHandleCallback { bool is_sync, std::string signature, std::unique_ptr details) = 0; + + // Return true iff ProofSourceHandle::ComputeSignature won't be called later. + // The handle can use this function to release resources promptly. + virtual bool WillNotCallComputeSignature() const = 0; }; // ProofSourceHandle is an interface by which a TlsServerHandshaker can obtain @@ -261,13 +292,17 @@ class QUIC_EXPORT_PRIVATE ProofSourceHandle { public: virtual ~ProofSourceHandle() = default; - // Cancel the pending operation, if any. - // Once called, any completion method on |callback()| won't be invoked. - virtual void CancelPendingOperation() = 0; + // Close the handle. Cancel the pending operation, if any. + // Once called, any completion method on |callback()| won't be invoked, and + // future SelectCertificate and ComputeSignature calls should return failure. + virtual void CloseHandle() = 0; // Starts a select certificate operation. If the operation is not cancelled // when it completes, callback()->OnSelectCertificateDone will be invoked. // + // server_address and client_address should be normalized by the caller before + // sending down to this function. + // // If the operation is handled synchronously: // - QUIC_SUCCESS or QUIC_FAILURE will be returned. // - callback()->OnSelectCertificateDone should be invoked before the function @@ -280,11 +315,14 @@ class QUIC_EXPORT_PRIVATE ProofSourceHandle { virtual QuicAsyncStatus SelectCertificate( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, + absl::string_view ssl_capabilities, const std::string& hostname, absl::string_view client_hello, const std::string& alpn, + absl::optional alps, const std::vector& quic_transport_params, - const absl::optional>& early_data_context) = 0; + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) = 0; // Starts a compute signature operation. If the operation is not cancelled // when it completes, callback()->OnComputeSignatureDone will be invoked. @@ -306,6 +344,12 @@ class QUIC_EXPORT_PRIVATE ProofSourceHandle { friend class test::FakeProofSourceHandle; }; +// Returns true if |chain| contains a parsable DER-encoded X.509 leaf cert and +// it matches with |key|. +QUIC_EXPORT_PRIVATE bool ValidateCertAndKey( + const QuicReferenceCountedPointer& chain, + const CertificatePrivateKey& key); + } // namespace quic #endif // QUICHE_QUIC_CORE_CRYPTO_PROOF_SOURCE_H_ diff --git a/gquiche/quic/core/crypto/proof_source_x509.cc b/gquiche/quic/core/crypto/proof_source_x509.cc index ec7e8b58..3724b3a9 100644 --- a/gquiche/quic/core/crypto/proof_source_x509.cc +++ b/gquiche/quic/core/crypto/proof_source_x509.cc @@ -53,7 +53,7 @@ void ProofSourceX509::GetProof( return; } - Certificate* certificate = GetCertificate(hostname); + Certificate* certificate = GetCertificate(hostname, &proof.cert_matched_sni); proof.signature = certificate->key.Sign(absl::string_view(payload.get(), payload_size), SSL_SIGN_RSA_PSS_RSAE_SHA256); @@ -63,9 +63,9 @@ void ProofSourceX509::GetProof( QuicReferenceCountedPointer ProofSourceX509::GetCertChain( const QuicSocketAddress& /*server_address*/, - const QuicSocketAddress& /*client_address*/, - const std::string& hostname) { - return GetCertificate(hostname)->chain; + const QuicSocketAddress& /*client_address*/, const std::string& hostname, + bool* cert_matched_sni) { + return GetCertificate(hostname, cert_matched_sni)->chain; } void ProofSourceX509::ComputeTlsSignature( @@ -75,11 +75,19 @@ void ProofSourceX509::ComputeTlsSignature( uint16_t signature_algorithm, absl::string_view in, std::unique_ptr callback) { - std::string signature = - GetCertificate(hostname)->key.Sign(in, signature_algorithm); + bool cert_matched_sni; + std::string signature = GetCertificate(hostname, &cert_matched_sni) + ->key.Sign(in, signature_algorithm); callback->Run(/*ok=*/!signature.empty(), signature, nullptr); } +absl::InlinedVector +ProofSourceX509::SupportedTlsSignatureAlgorithms() const { + // Let ComputeTlsSignature() report an error if a bad signature algorithm is + // requested. + return {}; +} + ProofSource::TicketCrypter* ProofSourceX509::GetTicketCrypter() { return nullptr; } @@ -118,9 +126,10 @@ bool ProofSourceX509::AddCertificateChain( } ProofSourceX509::Certificate* ProofSourceX509::GetCertificate( - const std::string& hostname) const { + const std::string& hostname, bool* cert_matched_sni) const { auto it = certificate_map_.find(hostname); if (it != certificate_map_.end()) { + *cert_matched_sni = true; return it->second; } auto dot_pos = hostname.find('.'); @@ -128,9 +137,11 @@ ProofSourceX509::Certificate* ProofSourceX509::GetCertificate( std::string wildcard = absl::StrCat("*", hostname.substr(dot_pos)); it = certificate_map_.find(wildcard); if (it != certificate_map_.end()) { + *cert_matched_sni = true; return it->second; } } + *cert_matched_sni = false; return default_certificate_; } diff --git a/gquiche/quic/core/crypto/proof_source_x509.h b/gquiche/quic/core/crypto/proof_source_x509.h index 9e92ee86..004c005a 100644 --- a/gquiche/quic/core/crypto/proof_source_x509.h +++ b/gquiche/quic/core/crypto/proof_source_x509.h @@ -37,15 +37,15 @@ class QUIC_EXPORT_PRIVATE ProofSourceX509 : public ProofSource { std::unique_ptr callback) override; QuicReferenceCountedPointer GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; void ComputeTlsSignature( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname, - uint16_t signature_algorithm, - absl::string_view in, + const QuicSocketAddress& client_address, const std::string& hostname, + uint16_t signature_algorithm, absl::string_view in, std::unique_ptr callback) override; + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override; TicketCrypter* GetTicketCrypter() override; // Adds a certificate chain to the verifier. Returns false if the chain is @@ -65,7 +65,8 @@ class QUIC_EXPORT_PRIVATE ProofSourceX509 : public ProofSource { // Looks up certficiate for hostname, returns the default if no certificate is // found. - Certificate* GetCertificate(const std::string& hostname) const; + Certificate* GetCertificate(const std::string& hostname, + bool* cert_matched_sni) const; std::forward_list certificates_; Certificate* default_certificate_; diff --git a/gquiche/quic/core/crypto/proof_source_x509_test.cc b/gquiche/quic/core/crypto/proof_source_x509_test.cc index 4fc362be..beb66e86 100644 --- a/gquiche/quic/core/crypto/proof_source_x509_test.cc +++ b/gquiche/quic/core/crypto/proof_source_x509_test.cc @@ -58,8 +58,7 @@ TEST_F(ProofSourceX509Test, AddCertificateKeyMismatch) { ProofSourceX509::Create(test_chain_, std::move(*test_key_)); ASSERT_TRUE(proof_source != nullptr); test_key_ = CertificatePrivateKey::LoadFromDer(kTestCertificatePrivateKey); - bool result; - EXPECT_QUIC_BUG(result = proof_source->AddCertificateChain( + EXPECT_QUIC_BUG((void)proof_source->AddCertificateChain( wildcard_chain_, std::move(*test_key_)), "Private key does not match"); } @@ -72,40 +71,47 @@ TEST_F(ProofSourceX509Test, CertificateSelection) { std::move(*wildcard_key_))); // Default certificate. + bool cert_matched_sni; EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "unknown.test") + "unknown.test", &cert_matched_sni) ->certs[0], kTestCertificate); + EXPECT_FALSE(cert_matched_sni); // mail.example.org is explicitly a SubjectAltName in kTestCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "mail.example.org") + "mail.example.org", &cert_matched_sni) ->certs[0], kTestCertificate); + EXPECT_TRUE(cert_matched_sni); // www.foo.test is in kWildcardCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "www.foo.test") + "www.foo.test", &cert_matched_sni) ->certs[0], kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); // *.wildcard.test is in kWildcardCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "www.wildcard.test") + "www.wildcard.test", &cert_matched_sni) ->certs[0], kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "etc.wildcard.test") + "etc.wildcard.test", &cert_matched_sni) ->certs[0], kWildcardCertificate); + EXPECT_TRUE(cert_matched_sni); // wildcard.test itself is not in kWildcardCertificate. EXPECT_EQ(proof_source ->GetCertChain(QuicSocketAddress(), QuicSocketAddress(), - "wildcard.test") + "wildcard.test", &cert_matched_sni) ->certs[0], kTestCertificate); + EXPECT_FALSE(cert_matched_sni); } TEST_F(ProofSourceX509Test, TlsSignature) { diff --git a/gquiche/quic/core/crypto/quic_client_session_cache.cc b/gquiche/quic/core/crypto/quic_client_session_cache.cc new file mode 100644 index 00000000..626bf003 --- /dev/null +++ b/gquiche/quic/core/crypto/quic_client_session_cache.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/crypto/quic_client_session_cache.h" + +#include "gquiche/quic/core/quic_clock.h" + +namespace quic { + +namespace { + +const size_t kDefaultMaxEntries = 1024; +// Returns false if the SSL |session| doesn't exist or it is expired at |now|. +bool IsValid(SSL_SESSION* session, uint64_t now) { + if (!session) return false; + + // now_u64 may be slightly behind because of differences in how + // time is calculated at this layer versus BoringSSL. + // Add a second of wiggle room to account for this. + return !(now + 1 < SSL_SESSION_get_time(session) || + now >= SSL_SESSION_get_time(session) + + SSL_SESSION_get_timeout(session)); +} + +bool DoApplicationStatesMatch(const ApplicationState* state, + ApplicationState* other) { + if ((state && !other) || (!state && other)) return false; + if ((!state && !other) || *state == *other) return true; + return false; +} + +} // namespace + +QuicClientSessionCache::QuicClientSessionCache() + : QuicClientSessionCache(kDefaultMaxEntries) {} + +QuicClientSessionCache::QuicClientSessionCache(size_t max_entries) + : cache_(max_entries) {} + +QuicClientSessionCache::~QuicClientSessionCache() { Clear(); } + +void QuicClientSessionCache::Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) { + QUICHE_DCHECK(session) << "TLS session is not inserted into client cache."; + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) { + CreateAndInsertEntry(server_id, std::move(session), params, + application_state); + return; + } + + QUICHE_DCHECK(iter->second->params); + // The states are both the same, so only need to insert sessions. + if (params == *iter->second->params && + DoApplicationStatesMatch(application_state, + iter->second->application_state.get())) { + iter->second->PushSession(std::move(session)); + return; + } + // Erase the existing entry because this Insert call must come from a + // different QUIC session. + cache_.Erase(iter); + CreateAndInsertEntry(server_id, std::move(session), params, + application_state); +} + +std::unique_ptr QuicClientSessionCache::Lookup( + const QuicServerId& server_id, QuicWallTime now, const SSL_CTX* /*ctx*/) { + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) return nullptr; + + if (!IsValid(iter->second->PeekSession(), now.ToUNIXSeconds())) { + QUIC_DLOG(INFO) << "TLS Session expired for host:" << server_id.host(); + cache_.Erase(iter); + return nullptr; + } + auto state = std::make_unique(); + state->tls_session = iter->second->PopSession(); + if (iter->second->params != nullptr) { + state->transport_params = + std::make_unique(*iter->second->params); + } + if (iter->second->application_state != nullptr) { + state->application_state = + std::make_unique(*iter->second->application_state); + } + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache) && + !iter->second->token.empty()) { + state->token = iter->second->token; + // Clear token after use. + iter->second->token.clear(); + } + + return state; +} + +void QuicClientSessionCache::ClearEarlyData(const QuicServerId& server_id) { + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) return; + for (auto& session : iter->second->sessions) { + if (session) { + QUIC_DLOG(INFO) << "Clear early data for for host: " << server_id.host(); + session.reset(SSL_SESSION_copy_without_early_data(session.get())); + } + } +} + +void QuicClientSessionCache::OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) { + if (token.empty()) { + return; + } + auto iter = cache_.Lookup(server_id); + if (iter == cache_.end()) { + return; + } + iter->second->token = std::string(token); +} + +void QuicClientSessionCache::RemoveExpiredEntries(QuicWallTime now) { + auto iter = cache_.begin(); + while (iter != cache_.end()) { + if (!IsValid(iter->second->PeekSession(), now.ToUNIXSeconds())) { + iter = cache_.Erase(iter); + } else { + ++iter; + } + } +} + +void QuicClientSessionCache::Clear() { cache_.Clear(); } + +void QuicClientSessionCache::CreateAndInsertEntry( + const QuicServerId& server_id, bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) { + auto entry = std::make_unique(); + entry->PushSession(std::move(session)); + entry->params = std::make_unique(params); + if (application_state) { + entry->application_state = + std::make_unique(*application_state); + } + cache_.Insert(server_id, std::move(entry)); +} + +QuicClientSessionCache::Entry::Entry() = default; +QuicClientSessionCache::Entry::Entry(Entry&&) = default; +QuicClientSessionCache::Entry::~Entry() = default; + +void QuicClientSessionCache::Entry::PushSession( + bssl::UniquePtr session) { + if (sessions[0] != nullptr) { + sessions[1] = std::move(sessions[0]); + } + sessions[0] = std::move(session); +} + +bssl::UniquePtr QuicClientSessionCache::Entry::PopSession() { + if (sessions[0] == nullptr) return nullptr; + bssl::UniquePtr session = std::move(sessions[0]); + sessions[0] = std::move(sessions[1]); + sessions[1] = nullptr; + return session; +} + +SSL_SESSION* QuicClientSessionCache::Entry::PeekSession() { + return sessions[0].get(); +} + +} // namespace quic diff --git a/gquiche/quic/core/crypto/quic_client_session_cache.h b/gquiche/quic/core/crypto/quic_client_session_cache.h new file mode 100644 index 00000000..75956ca3 --- /dev/null +++ b/gquiche/quic/core/crypto/quic_client_session_cache.h @@ -0,0 +1,82 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_CRYPTO_QUIC_CLIENT_SESSION_CACHE_H_ +#define QUICHE_QUIC_CORE_CRYPTO_QUIC_CLIENT_SESSION_CACHE_H_ + +#include + +#include "gquiche/quic/core/crypto/quic_crypto_client_config.h" +#include "gquiche/quic/core/quic_lru_cache.h" +#include "gquiche/quic/core/quic_server_id.h" + +namespace quic { + +namespace test { +class QuicClientSessionCachePeer; +} // namespace test + +// QuicClientSessionCache maps from QuicServerId to information used to resume +// TLS sessions for that server. +class QUIC_EXPORT_PRIVATE QuicClientSessionCache : public SessionCache { + public: + QuicClientSessionCache(); + explicit QuicClientSessionCache(size_t max_entries); + ~QuicClientSessionCache() override; + + void Insert(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state) override; + + std::unique_ptr Lookup(const QuicServerId& server_id, + QuicWallTime now, + const SSL_CTX* ctx) override; + + void ClearEarlyData(const QuicServerId& server_id) override; + + void OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) override; + + void RemoveExpiredEntries(QuicWallTime now) override; + + void Clear() override; + + size_t size() const { return cache_.Size(); } + + private: + friend class test::QuicClientSessionCachePeer; + + struct QUIC_EXPORT_PRIVATE Entry { + Entry(); + Entry(Entry&&); + ~Entry(); + + // Adds a new |session| onto sessions, dropping the oldest one if two are + // already stored. + void PushSession(bssl::UniquePtr session); + + // Retrieves the latest session from the entry, meanwhile removing it. + bssl::UniquePtr PopSession(); + + SSL_SESSION* PeekSession(); + + bssl::UniquePtr sessions[2]; + std::unique_ptr params; + std::unique_ptr application_state; + std::string token; // An opaque string received in NEW_TOKEN frame. + }; + + // Creates a new entry and insert into |cache_|. + void CreateAndInsertEntry(const QuicServerId& server_id, + bssl::UniquePtr session, + const TransportParameters& params, + const ApplicationState* application_state); + + QuicLRUCache cache_; +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_CRYPTO_QUIC_CLIENT_SESSION_CACHE_H_ diff --git a/gquiche/quic/core/crypto/quic_client_session_cache_test.cc b/gquiche/quic/core/crypto/quic_client_session_cache_test.cc new file mode 100644 index 00000000..0e91a373 --- /dev/null +++ b/gquiche/quic/core/crypto/quic_client_session_cache_test.cc @@ -0,0 +1,440 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/crypto/quic_client_session_cache.h" + +#include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/quic/test_tools/mock_clock.h" +#include "gquiche/common/quiche_text_utils.h" + +namespace quic { +namespace test { +namespace { + +const QuicTime::Delta kTimeout = QuicTime::Delta::FromSeconds(1000); +const QuicVersionLabel kFakeVersionLabel = 0x01234567; +const QuicVersionLabel kFakeVersionLabel2 = 0x89ABCDEF; +const uint64_t kFakeIdleTimeoutMilliseconds = 12012; +const uint8_t kFakeStatelessResetTokenData[16] = { + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9A, 0x9B, 0x9C, 0x9D, 0x9E, 0x9F}; +const uint64_t kFakeMaxPacketSize = 9001; +const uint64_t kFakeInitialMaxData = 101; +const bool kFakeDisableMigration = true; +const auto kCustomParameter1 = + static_cast(0xffcd); +const char* kCustomParameter1Value = "foo"; +const auto kCustomParameter2 = + static_cast(0xff34); +const char* kCustomParameter2Value = "bar"; + +std::vector CreateFakeStatelessResetToken() { + return std::vector( + kFakeStatelessResetTokenData, + kFakeStatelessResetTokenData + sizeof(kFakeStatelessResetTokenData)); +} + +TransportParameters::LegacyVersionInformation +CreateFakeLegacyVersionInformation() { + TransportParameters::LegacyVersionInformation legacy_version_information; + legacy_version_information.version = kFakeVersionLabel; + legacy_version_information.supported_versions.push_back(kFakeVersionLabel); + legacy_version_information.supported_versions.push_back(kFakeVersionLabel2); + return legacy_version_information; +} + +TransportParameters::VersionInformation CreateFakeVersionInformation() { + TransportParameters::VersionInformation version_information; + version_information.chosen_version = kFakeVersionLabel; + version_information.other_versions.push_back(kFakeVersionLabel); + return version_information; +} + +// Make a TransportParameters that has a few fields set to help test comparison. +std::unique_ptr MakeFakeTransportParams() { + auto params = std::make_unique(); + params->perspective = Perspective::IS_CLIENT; + params->legacy_version_information = CreateFakeLegacyVersionInformation(); + params->version_information = CreateFakeVersionInformation(); + params->max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + params->stateless_reset_token = CreateFakeStatelessResetToken(); + params->max_udp_payload_size.set_value(kFakeMaxPacketSize); + params->initial_max_data.set_value(kFakeInitialMaxData); + params->disable_active_migration = kFakeDisableMigration; + params->custom_parameters[kCustomParameter1] = kCustomParameter1Value; + params->custom_parameters[kCustomParameter2] = kCustomParameter2Value; + return params; +} + +// Generated by running TlsClientHandshakerTest.ZeroRttResumption and in +// TlsClientHandshaker::InsertSession calling SSL_SESSION_to_bytes to serialize +// the received 0-RTT capable ticket. +static const char kCachedSession[] = + "30820ad7020101020203040402130104206594ce84e61a866b56163c4ba09079aebf1d4f" + "6cbcbd38dc9d7066a38a76c9cf0420ec9062063582a4cc0a44f9ff93256a195153ba6032" + "0cf3c9189990932d838adaa10602046196f7b9a205020302a300a382039f3082039b3082" + "0183a00302010202021001300d06092a864886f70d010105050030623111300f06035504" + "030c08426f677573204941310b300906035504080c024d41310b30090603550406130255" + "533121301f06092a864886f70d0109011612626f67757340626f6775732d69612e636f6d" + "3110300e060355040a0c07426f6775734941301e170d3231303132383136323030315a17" + "0d3331303132363136323030315a3069311d301b06035504030c14746573745f6563632e" + "6578616d706c652e636f6d310b300906035504080c024d41310b30090603550406130255" + "53311e301c06092a864886f70d010901160f626f67757340626f6775732e636f6d310e30" + "0c060355040a0c05426f6775733059301306072a8648ce3d020106082a8648ce3d030107" + "034200041ba5e2b6f24e64990b9f24ae6d23473d8c77fbcfb7f554f36559529a69a57170" + "a10a81b7fe4a36ebf37b0a8c5e467a8443d8b8c002892aa5c1194bd843f42c9aa31f301d" + "301b0603551d11041430128210746573742e6578616d706c652e636f6d300d06092a8648" + "86f70d0101050500038202010019921d54ac06948763d609215f64f5d6540e3da886c6c9" + "61bc737a437719b4621416ef1229f39282d7d3234e1a5d57535473066233bd246eec8e96" + "1e0633cf4fe014c800e62599981820ec33d92e74ded0fa2953db1d81e19cb6890b6305b6" + "3ede8d3e9fcf3c09f3f57283acf08aa57be4ee9a68d00bb3e2ded5920c619b5d83e5194a" + "adb77ae5d61ed3e0a5670f0ae61cc3197329f0e71e3364dcab0405e9e4a6646adef8f022" + "6415ec16c8046307b1769029fe780bd576114dde2fa9b4a32aa70bc436549a24ee4907a9" + "045f6457ce8dfd8d62cc65315afe798ae1a948eefd70b035d415e73569c48fb20085de1a" + "87de039e6b0b9a5fcb4069df27f3a7a1409e72d1ac739c72f29ef786134207e61c79855f" + "c22e3ee5f6ad59a7b1ff0f18d79776f1c95efaebbebe381664132a58a1e7ff689945b7e0" + "88634b0872feeefbf6be020884b994c6a7ff435f2b3f609077ff97cb509cfa17ff479b34" + "e633e4b5bc46b20c5f27c80a2e2943f795a928acd5a3fc43c3af8425ad600c048b41d87e" + "6361bc72fc4e5e44680a3d325674ba6ffa760d2fc7d9e4847a8e0dd9d35a543324e18b94" + "2d42af6391ed1dd54a39e3f4a4c6b32486eb4ba72815dbd89c56fc053743a0b0483ce676" + "15defce6800c629b99d0cbc56da162487f475b7c246099eaf1e6d10a022b2f49c6af1da3" + "e8ed66096f267c4a76976b9572db7456ef90278330a4020400aa81b60481b3494e534543" + "55524500f3439e548c21d2ad6e5634cc1cc0045730819702010102020304040213010400" + "0420ec9062063582a4cc0a44f9ff93256a195153ba60320cf3c9189990932d838adaa106" + "02046196f7b9a205020302a300a4020400b20302011db5060404130800cdb807020500ff" + "ffffffb9050203093a80ba0404026833bb030101ffbc23042100d27d985bfce04833f02d" + "38366b219f4def42bc4ba1b01844d1778db11731487dbd020400be020400b20302011db3" + "8205da308205d6308203bea00302010202021000300d06092a864886f70d010105050030" + "62310b3009060355040613025553310b300906035504080c024d413110300e060355040a" + "0c07426f67757343413111300f06035504030c08426f6775732043413121301f06092a86" + "4886f70d0109011612626f67757340626f6775732d63612e636f6d3020170d3231303132" + "383136313935385a180f32303730303531313136313935385a30623111300f0603550403" + "0c08426f677573204941310b300906035504080c024d41310b3009060355040613025553" + "3121301f06092a864886f70d0109011612626f67757340626f6775732d69612e636f6d31" + "10300e060355040a0c07426f677573494130820222300d06092a864886f70d0101010500" + "0382020f003082020a028202010096c03a0ffc61bcedcd5ec9bf6f848b8a066b43f08377" + "3af518a6a0044f22e666e24d2ae741954e344302c4be04612185bd53bcd848eb322bf900" + "724eb0848047d647033ffbddb00f01d1de7c1cdb684f83c9bf5fd18ff60afad5a53b0d7d" + "2c2a50abc38df019cd7f50194d05bc4597a1ef8570ea04069a2c36d74496af126573ca18" + "8e470009b56250fadf2a04e837ee3837b36b1f08b7a0cfe2533d05f26484ce4e30203d01" + "517fffd3da63d0341079ddce16e9ab4dbf9d4049e5cc52326031e645dd682fe6220d9e0e" + "95451f5a82f3e1720dc13e8499466426a0bdbea9f6a76b3c9228dd3c79ab4dcc4c145ef0" + "e78d1ee8bfd4650692d7e28a54bed809d8f7b37fe24c586be59cc46638531cb291c8c156" + "8f08d67e768e51563e95a639c1f138b275ffad6a6a2a042ba9e26ad63c2ce63b600013f0" + "a6f0703ee51c4f457f7bab0391c2fc4c5bb3213742c9cf9941bff68cc2e1cc96139d35ed" + "1885244ddde0bf658416c486701841b81f7b17503d08c59a4db08a2a80755e007aa3b6c7" + "eadcaa9e07c8325f3689f100de23970b12c9d9f6d0a8fb35ba0fd75c64410318db4a13ac" + "3972ad16cdf6408af37013c7bcd7c42f20d6d04c3e39436c7531e8dafa219dd04b784ef0" + "3c70ee5a4782b33cafa925aa3deca62a14aed704f179b932efabc2b0c5c15a8a99bfc9e6" + "189dce7da50ea303594b6af9c933dd54b6e9d17c472d0203010001a38193308190300f06" + "03551d130101ff040530030101ff301d0603551d0e041604141a98e80029a80992b7e5e0" + "068ab9b3486cd839d6301f0603551d23041830168014780beeefe2fa419c48a438bdb30b" + "e37ef0b7a94e300b0603551d0f0404030202a430130603551d25040c300a06082b060105" + "05070301301b0603551d11041430128207426f67757343418207426f6775734941300d06" + "092a864886f70d010105050003820201009e822ed8064b1aabaddf1340010ea147f68c06" + "5a5a599ea305349f1b0e545a00817d6e55c7bf85560fab429ca72186c4d520b52f5cc121" + "abd068b06f3111494431d2522efa54642f907059e7db80b73bb5ecf621377195b8700bba" + "df798cece8c67a9571548d0e6592e81ae5d934877cb170aef18d3b97f635600fe0890d98" + "f88b33fe3d1fd34c1c915beae4e5c0b133f476c40b21d220f16ce9cdd9e8f97a36a31723" + "68875f052c9271648d9cb54687c6fdc3ea96f2908003bc5e5e79de00a21da7b8429f8b08" + "af4c4d34641e386d72eabf5f01f106363f2ffd18969bf0bb9a4d17627c6427ff772c4308" + "83c276feef5fc6dba9582c22fdbe9df7e8dfca375695f028ed588df54f3c86462dbf4c07" + "91d80ca738988a1419c86bb4dd8d738b746921f01f39422e5ffd488b6f00195b996e6392" + "3a820a32cd78b5989f339c0fcf4f269103964a30a16347d0ffdc8df1f3653ddc1515fa09" + "22c7aef1af1fbcb23e93ae7622ab1ee11fcfa98319bad4c37c091cad46bd0337b3cc78b5" + "5b9f1ea7994acc1f89c49a0b4cb540d2137e266fd43e56a9b5b778217b6f77df530e1eaf" + "b3417262b5ddb86d3c6c5ac51e3f326c650dcc2434473973b7182c66220d1f3871bde7ee" + "47d3f359d3d4c5bdd61baa684c03db4c75f9d6690c9e6e3abe6eaf5fa2c33c4daf26b373" + "d85a1e8a7d671ac4a0a97b14e36e81280de4593bbb12da7695b5060404130800cdb60301" + "0100b70402020403b807020500ffffffffb9050203093a80ba0404026833bb030101ffbd" + "020400be020400"; + +class QuicClientSessionCacheTest : public QuicTest { + public: + QuicClientSessionCacheTest() : ssl_ctx_(SSL_CTX_new(TLS_method())) { + clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); + } + + protected: + bssl::UniquePtr NewSSLSession() { + std::string cached_session = + absl::HexStringToBytes(absl::string_view(kCachedSession)); + SSL_SESSION* session = SSL_SESSION_from_bytes( + reinterpret_cast(cached_session.data()), + cached_session.size(), ssl_ctx_.get()); + QUICHE_DCHECK(session); + return bssl::UniquePtr(session); + } + + bssl::UniquePtr MakeTestSession( + QuicTime::Delta timeout = kTimeout) { + bssl::UniquePtr session = NewSSLSession(); + SSL_SESSION_set_time(session.get(), clock_.WallNow().ToUNIXSeconds()); + SSL_SESSION_set_timeout(session.get(), timeout.ToSeconds()); + return session; + } + + bssl::UniquePtr ssl_ctx_; + MockClock clock_; +}; + +// Tests that simple insertion and lookup work correctly. +TEST_F(QuicClientSessionCacheTest, SingleSession) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + + auto params2 = MakeFakeTransportParams(); + auto session2 = MakeTestSession(); + SSL_SESSION* unowned2 = session2.get(); + QuicServerId id2("b.com", 443); + + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(nullptr, cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(0u, cache.size()); + + cache.Insert(id1, std::move(session), *params, nullptr); + EXPECT_EQ(1u, cache.size()); + EXPECT_EQ( + *params, + *(cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())->transport_params)); + EXPECT_EQ(nullptr, cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())); + // No session is available for id1, even though the entry exists. + EXPECT_EQ(1u, cache.size()); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + // Lookup() will trigger a deletion of invalid entry. + EXPECT_EQ(0u, cache.size()); + + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + QuicServerId id3("c.com", 443); + cache.Insert(id3, std::move(session3), *params, nullptr); + cache.Insert(id2, std::move(session2), *params2, nullptr); + EXPECT_EQ(2u, cache.size()); + EXPECT_EQ( + unowned2, + cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ( + unowned3, + cache.Lookup(id3, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + + // Verify that the cache is cleared after Lookups. + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(nullptr, cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(nullptr, cache.Lookup(id3, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(0u, cache.size()); +} + +TEST_F(QuicClientSessionCacheTest, MultipleSessions) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + SSL_SESSION* unowned2 = session2.get(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id1, std::move(session2), *params, nullptr); + cache.Insert(id1, std::move(session3), *params, nullptr); + // The latest session is popped first. + EXPECT_EQ( + unowned3, + cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ( + unowned2, + cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + // Only two sessions are cached. + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +// Test that when a different TransportParameter is inserted for +// the same server id, the existing entry is removed. +TEST_F(QuicClientSessionCacheTest, DifferentTransportParams) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id1, std::move(session2), *params, nullptr); + // tweak the transport parameters a little bit. + params->perspective = Perspective::IS_SERVER; + cache.Insert(id1, std::move(session3), *params, nullptr); + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_EQ(unowned3, resumption_state->tls_session.get()); + EXPECT_EQ(*params.get(), *resumption_state->transport_params); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +TEST_F(QuicClientSessionCacheTest, DifferentApplicationState) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + ApplicationState state; + state.push_back('a'); + + cache.Insert(id1, std::move(session), *params, &state); + cache.Insert(id1, std::move(session2), *params, &state); + cache.Insert(id1, std::move(session3), *params, nullptr); + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_EQ(unowned3, resumption_state->tls_session.get()); + EXPECT_EQ(nullptr, resumption_state->application_state); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +TEST_F(QuicClientSessionCacheTest, BothStatesDifferent) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + ApplicationState state; + state.push_back('a'); + + cache.Insert(id1, std::move(session), *params, &state); + cache.Insert(id1, std::move(session2), *params, &state); + params->perspective = Perspective::IS_SERVER; + cache.Insert(id1, std::move(session3), *params, nullptr); + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_EQ(unowned3, resumption_state->tls_session.get()); + EXPECT_EQ(*params.get(), *resumption_state->transport_params); + EXPECT_EQ(nullptr, resumption_state->application_state); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +// When the size limit is exceeded, the oldest entry should be erased. +TEST_F(QuicClientSessionCacheTest, SizeLimit) { + QuicClientSessionCache cache(2); + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + + auto session2 = MakeTestSession(); + SSL_SESSION* unowned2 = session2.get(); + QuicServerId id2("b.com", 443); + + auto session3 = MakeTestSession(); + SSL_SESSION* unowned3 = session3.get(); + QuicServerId id3("c.com", 443); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id2, std::move(session2), *params, nullptr); + cache.Insert(id3, std::move(session3), *params, nullptr); + + EXPECT_EQ(2u, cache.size()); + EXPECT_EQ( + unowned2, + cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ( + unowned3, + cache.Lookup(id3, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +TEST_F(QuicClientSessionCacheTest, ClearEarlyData) { + QuicClientSessionCache cache; + SSL_CTX_set_early_data_enabled(ssl_ctx_.get(), 1); + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + auto session2 = MakeTestSession(); + + EXPECT_TRUE(SSL_SESSION_early_data_capable(session.get())); + EXPECT_TRUE(SSL_SESSION_early_data_capable(session2.get())); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id1, std::move(session2), *params, nullptr); + + cache.ClearEarlyData(id1); + + auto resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_FALSE( + SSL_SESSION_early_data_capable(resumption_state->tls_session.get())); + resumption_state = cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get()); + EXPECT_FALSE( + SSL_SESSION_early_data_capable(resumption_state->tls_session.get())); + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); +} + +// Expired session isn't considered valid and nullptr will be returned upon +// Lookup. +TEST_F(QuicClientSessionCacheTest, Expiration) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + QuicServerId id1("a.com", 443); + + auto session2 = MakeTestSession(3 * kTimeout); + SSL_SESSION* unowned2 = session2.get(); + QuicServerId id2("b.com", 443); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id2, std::move(session2), *params, nullptr); + + EXPECT_EQ(2u, cache.size()); + // Expire the session. + clock_.AdvanceTime(kTimeout * 2); + // The entry has not been removed yet. + EXPECT_EQ(2u, cache.size()); + + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(1u, cache.size()); + EXPECT_EQ( + unowned2, + cache.Lookup(id2, clock_.WallNow(), ssl_ctx_.get())->tls_session.get()); + EXPECT_EQ(1u, cache.size()); +} + +TEST_F(QuicClientSessionCacheTest, RemoveExpiredEntriesAndClear) { + QuicClientSessionCache cache; + + auto params = MakeFakeTransportParams(); + auto session = MakeTestSession(); + quic::QuicServerId id1("a.com", 443); + + auto session2 = MakeTestSession(3 * kTimeout); + quic::QuicServerId id2("b.com", 443); + + cache.Insert(id1, std::move(session), *params, nullptr); + cache.Insert(id2, std::move(session2), *params, nullptr); + + EXPECT_EQ(2u, cache.size()); + // Expire the session. + clock_.AdvanceTime(kTimeout * 2); + // The entry has not been removed yet. + EXPECT_EQ(2u, cache.size()); + + // Flush expired sessions. + cache.RemoveExpiredEntries(clock_.WallNow()); + + // session is expired and should be flushed. + EXPECT_EQ(nullptr, cache.Lookup(id1, clock_.WallNow(), ssl_ctx_.get())); + EXPECT_EQ(1u, cache.size()); + + cache.Clear(); + EXPECT_EQ(0u, cache.size()); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/gquiche/quic/core/crypto/quic_compressed_certs_cache.cc b/gquiche/quic/core/crypto/quic_compressed_certs_cache.cc index 6991c3af..a3363546 100644 --- a/gquiche/quic/core/crypto/quic_compressed_certs_cache.cc +++ b/gquiche/quic/core/crypto/quic_compressed_certs_cache.cc @@ -81,7 +81,11 @@ const std::string* QuicCompressedCertsCache::GetCompressedCert( uint64_t key = ComputeUncompressedCertsHash(uncompressed_certs); - CachedCerts* cached_value = certs_cache_.Lookup(key); + CachedCerts* cached_value = nullptr; + auto iter = certs_cache_.Lookup(key); + if (iter != certs_cache_.end()) { + cached_value = iter->second.get(); + } if (cached_value != nullptr && cached_value->MatchesUncompressedCerts(uncompressed_certs)) { return cached_value->compressed_cert(); diff --git a/gquiche/quic/core/crypto/quic_crypto_client_config.cc b/gquiche/quic/core/crypto/quic_crypto_client_config.cc index 6fb51c30..df5a0945 100644 --- a/gquiche/quic/core/crypto/quic_crypto_client_config.cc +++ b/gquiche/quic/core/crypto/quic_crypto_client_config.cc @@ -33,8 +33,6 @@ #include "gquiche/quic/platform/api/quic_client_stats.h" #include "gquiche/quic/platform/api/quic_hostname_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { @@ -441,7 +439,9 @@ void QuicCryptoClientConfig::FillInchoateClientHello( out->SetVector(kPDMD, QuicTagVector{kX509}); - if (common_cert_sets) { + if (GetQuicRestartFlag(quic_no_common_cert_set)) { + // Client only. No flag count. + } else if (common_cert_sets) { out->SetStringPiece(kCCS, common_cert_sets->GetCommonHashes()); } @@ -803,12 +803,12 @@ SessionCache* QuicCryptoClientConfig::session_cache() const { return session_cache_.get(); } -ProofSource* QuicCryptoClientConfig::proof_source() const { +ClientProofSource* QuicCryptoClientConfig::proof_source() const { return proof_source_.get(); } void QuicCryptoClientConfig::set_proof_source( - std::unique_ptr proof_source) { + std::unique_ptr proof_source) { proof_source_ = std::move(proof_source); } @@ -849,22 +849,23 @@ bool QuicCryptoClientConfig::PopulateFromCanonicalConfig( QuicServerId suffix_server_id(canonical_suffixes_[i], server_id.port(), server_id.privacy_mode_enabled()); - if (!QuicContainsKey(canonical_server_map_, suffix_server_id)) { + auto it = canonical_server_map_.lower_bound(suffix_server_id); + if (it == canonical_server_map_.end() || it->first != suffix_server_id) { // This is the first host we've seen which matches the suffix, so make it - // canonical. - canonical_server_map_[suffix_server_id] = server_id; + // canonical. Use |it| as position hint for faster insertion. + canonical_server_map_.insert( + it, std::make_pair(std::move(suffix_server_id), std::move(server_id))); return false; } - const QuicServerId& canonical_server_id = - canonical_server_map_[suffix_server_id]; + const QuicServerId& canonical_server_id = it->second; CachedState* canonical_state = cached_states_[canonical_server_id].get(); if (!canonical_state->proof_valid()) { return false; } // Update canonical version to point at the "most recent" entry. - canonical_server_map_[suffix_server_id] = server_id; + it->second = server_id; server_state->InitializeFrom(*canonical_state); return true; diff --git a/gquiche/quic/core/crypto/quic_crypto_client_config.h b/gquiche/quic/core/crypto/quic_crypto_client_config.h index a411387c..8c70a953 100644 --- a/gquiche/quic/core/crypto/quic_crypto_client_config.h +++ b/gquiche/quic/core/crypto/quic_crypto_client_config.h @@ -14,9 +14,9 @@ #include "absl/strings/string_view.h" #include "openssl/base.h" #include "openssl/ssl.h" +#include "gquiche/quic/core/crypto/client_proof_source.h" #include "gquiche/quic/core/crypto/crypto_handshake.h" #include "gquiche/quic/core/crypto/crypto_protocol.h" -#include "gquiche/quic/core/crypto/proof_source.h" #include "gquiche/quic/core/crypto/transport_parameters.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_server_id.h" @@ -49,6 +49,9 @@ struct QUIC_EXPORT_PRIVATE QuicResumptionState { // client received from the server at the application layer that the client // needs to remember when performing a 0-RTT handshake. std::unique_ptr application_state = nullptr; + + // Opaque token received in NEW_TOKEN frame if any. + std::string token; }; // SessionCache is an interface for managing storing and retrieving @@ -74,12 +77,21 @@ class QUIC_EXPORT_PRIVATE SessionCache { // delete cache entries after returning them in Lookup so that session tickets // are used only once. virtual std::unique_ptr Lookup( - const QuicServerId& server_id, - const SSL_CTX* ctx) = 0; + const QuicServerId& server_id, QuicWallTime now, const SSL_CTX* ctx) = 0; // Called when 0-RTT is rejected. Disables early data for all the TLS tickets // associated with |server_id|. virtual void ClearEarlyData(const QuicServerId& server_id) = 0; + + // Called when NEW_TOKEN frame is received. + virtual void OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) = 0; + + // Called to remove expired entries. + virtual void RemoveExpiredEntries(QuicWallTime now) = 0; + + // Clear the session cache. + virtual void Clear() = 0; }; // QuicCryptoClientConfig contains crypto-related configuration settings for a @@ -336,8 +348,8 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { ProofVerifier* proof_verifier() const; SessionCache* session_cache() const; - ProofSource* proof_source() const; - void set_proof_source(std::unique_ptr proof_source); + ClientProofSource* proof_source() const; + void set_proof_source(std::unique_ptr proof_source); SSL_CTX* ssl_ctx() const; // Initialize the CachedState from |canonical_crypto_config| for the @@ -362,6 +374,14 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { // handshake message. const std::string& user_agent_id() const { return user_agent_id_; } + void set_tls_signature_algorithms(std::string signature_algorithms) { + tls_signature_algorithms_ = std::move(signature_algorithms); + } + + const absl::optional& tls_signature_algorithms() const { + return tls_signature_algorithms_; + } + // Saves the |alpn| that will be passed in QUIC's CHLO message. void set_alpn(const std::string& alpn) { alpn_ = alpn; } @@ -381,6 +401,8 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { bool pad_full_hello() const { return pad_full_hello_; } void set_pad_full_hello(bool new_value) { pad_full_hello_ = new_value; } + SessionCache* mutable_session_cache() { return session_cache_.get(); } + private: // Sets the members to reasonable, default values. void SetDefaults(); @@ -421,7 +443,7 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { std::unique_ptr proof_verifier_; std::unique_ptr session_cache_; - std::unique_ptr proof_source_; + std::unique_ptr proof_source_; bssl::UniquePtr ssl_ctx_; @@ -435,6 +457,10 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientConfig : public QuicCryptoConfig { // incorporating |pre_shared_key_| into the key schedule. std::string pre_shared_key_; + // If set, configure the client to use the specified signature algorithms, via + // SSL_set1_sigalgs_list. TLS only. + absl::optional tls_signature_algorithms_; + // In QUIC, technically, client hello should be fully padded. // However, fully padding on slow network connection (e.g. 50kbps) can add // 150ms latency to one roundtrip. Therefore, you can disable padding of diff --git a/gquiche/quic/core/crypto/quic_crypto_client_config_test.cc b/gquiche/quic/core/crypto/quic_crypto_client_config_test.cc index e6a92ee6..86075c9a 100644 --- a/gquiche/quic/core/crypto/quic_crypto_client_config_test.cc +++ b/gquiche/quic/core/crypto/quic_crypto_client_config_test.cc @@ -179,6 +179,7 @@ TEST_F(QuicCryptoClientConfigTest, InchoateChloSecureWithSCIDNoEXPY) { QuicWallTime expiry = QuicWallTime::FromUNIXSeconds(2); state.SetServerConfig(scfg.GetSerialized().AsStringPiece(), now, expiry, &details); + EXPECT_FALSE(state.IsEmpty()); QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); QuicReferenceCountedPointer params( @@ -205,6 +206,7 @@ TEST_F(QuicCryptoClientConfigTest, InchoateChloSecureWithSCID) { state.SetServerConfig(scfg.GetSerialized().AsStringPiece(), QuicWallTime::FromUNIXSeconds(1), QuicWallTime::FromUNIXSeconds(0), &details); + EXPECT_FALSE(state.IsEmpty()); QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); QuicReferenceCountedPointer params( @@ -503,5 +505,46 @@ TEST_F(QuicCryptoClientConfigTest, ServerNonceinSHLO) { EXPECT_EQ("server hello missing server nonce", error_details); } +// Test that PopulateFromCanonicalConfig() handles the case of multiple entries +// in |canonical_server_map_|. +TEST_F(QuicCryptoClientConfigTest, MultipleCanonicalEntries) { + QuicCryptoClientConfig config(crypto_test_utils::ProofVerifierForTesting()); + config.AddCanonicalSuffix(".google.com"); + QuicServerId canonical_server_id1("www.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state1 = + config.LookupOrCreate(canonical_server_id1); + + CryptoHandshakeMessage scfg; + scfg.set_tag(kSCFG); + scfg.SetStringPiece(kSCID, "12345678"); + std::string details; + QuicWallTime now = QuicWallTime::FromUNIXSeconds(1); + QuicWallTime expiry = QuicWallTime::FromUNIXSeconds(2); + state1->SetServerConfig(scfg.GetSerialized().AsStringPiece(), now, expiry, + &details); + state1->set_source_address_token("TOKEN"); + state1->SetProofValid(); + EXPECT_FALSE(state1->IsEmpty()); + + // This will have the same |suffix_server_id| as |canonical_server_id1|, + // therefore |*state2| will be initialized from |*state1|. + QuicServerId canonical_server_id2("mail.google.com", 443, false); + QuicCryptoClientConfig::CachedState* state2 = + config.LookupOrCreate(canonical_server_id2); + EXPECT_FALSE(state2->IsEmpty()); + const CryptoHandshakeMessage* const scfg2 = state2->GetServerConfig(); + ASSERT_TRUE(scfg2); + EXPECT_EQ(kSCFG, scfg2->tag()); + + // With a different |suffix_server_id|, this will return an empty CachedState. + config.AddCanonicalSuffix(".example.com"); + QuicServerId canonical_server_id3("www.example.com", 443, false); + QuicCryptoClientConfig::CachedState* state3 = + config.LookupOrCreate(canonical_server_id3); + EXPECT_TRUE(state3->IsEmpty()); + const CryptoHandshakeMessage* const scfg3 = state3->GetServerConfig(); + EXPECT_FALSE(scfg3); +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/core/crypto/quic_crypto_proof.cc b/gquiche/quic/core/crypto/quic_crypto_proof.cc index 5bee5f83..da91ac24 100644 --- a/gquiche/quic/core/crypto/quic_crypto_proof.cc +++ b/gquiche/quic/core/crypto/quic_crypto_proof.cc @@ -6,6 +6,7 @@ namespace quic { -QuicCryptoProof::QuicCryptoProof() : send_expect_ct_header(false) {} +QuicCryptoProof::QuicCryptoProof() + : send_expect_ct_header(false), cert_matched_sni(false) {} } // namespace quic diff --git a/gquiche/quic/core/crypto/quic_crypto_proof.h b/gquiche/quic/core/crypto/quic_crypto_proof.h index f36f870f..df13a0ec 100644 --- a/gquiche/quic/core/crypto/quic_crypto_proof.h +++ b/gquiche/quic/core/crypto/quic_crypto_proof.h @@ -22,6 +22,9 @@ struct QUIC_EXPORT_PRIVATE QuicCryptoProof { // Should the Expect-CT header be sent on the connection where the // certificate is used. bool send_expect_ct_header; + // Did the selected leaf certificate contain a SubjectAltName that included + // the requested SNI. + bool cert_matched_sni; }; } // namespace quic diff --git a/gquiche/quic/core/crypto/quic_crypto_server_config.cc b/gquiche/quic/core/crypto/quic_crypto_server_config.cc index 4ee77bc6..66339e0a 100644 --- a/gquiche/quic/core/crypto/quic_crypto_server_config.cc +++ b/gquiche/quic/core/crypto/quic_crypto_server_config.cc @@ -29,12 +29,10 @@ #include "gquiche/quic/core/crypto/key_exchange.h" #include "gquiche/quic/core/crypto/p256_key_exchange.h" #include "gquiche/quic/core/crypto/proof_source.h" -#include "gquiche/quic/core/crypto/proof_verifier.h" #include "gquiche/quic/core/crypto/quic_decrypter.h" #include "gquiche/quic/core/crypto/quic_encrypter.h" #include "gquiche/quic/core/crypto/quic_hkdf.h" #include "gquiche/quic/core/crypto/quic_random.h" -#include "gquiche/quic/core/crypto/server_proof_verifier.h" #include "gquiche/quic/core/crypto/tls_server_connection.h" #include "gquiche/quic/core/proto/crypto_server_config_proto.h" #include "gquiche/quic/core/proto/source_address_token_proto.h" @@ -51,7 +49,6 @@ #include "gquiche/quic/platform/api/quic_reference_counted.h" #include "gquiche/quic/platform/api/quic_socket_address.h" #include "gquiche/quic/platform/api/quic_testvalue.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { @@ -245,7 +242,6 @@ QuicCryptoServerConfig::QuicCryptoServerConfig( primary_config_(nullptr), next_config_promotion_time_(QuicWallTime::Zero()), proof_source_(std::move(proof_source)), - client_cert_mode_(ClientCertMode::kNone), key_exchange_source_(std::move(key_exchange_source)), ssl_ctx_(TlsServerConnection::CreateSslCtx(proof_source_.get())), source_address_token_future_secs_(3600), @@ -745,9 +741,7 @@ void QuicCryptoServerConfig::ProcessClientHelloAfterGetProof( << context->connection_id() << " which is invalid with version " << context->version(); - if (context->validate_chlo_result()->postpone_cert_validate_for_server && - context->info().reject_reasons.empty()) { - QUIC_RELOADABLE_FLAG_COUNT(quic_crypto_postpone_cert_validate_for_server); + if (context->info().reject_reasons.empty()) { if (!context->signed_config() || !context->signed_config()->chain) { // No chain. context->validate_chlo_result()->info.reject_reasons.push_back( @@ -1224,8 +1218,8 @@ void QuicCryptoServerConfig::SelectNewPrimaryConfig( } void QuicCryptoServerConfig::EvaluateClientHello( - const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, + const QuicSocketAddress& /*server_address*/, + const QuicSocketAddress& /*client_address*/, QuicTransportVersion /*version*/, const Configs& configs, QuicReferenceCountedPointer @@ -1253,7 +1247,7 @@ void QuicCryptoServerConfig::EvaluateClientHello( configs.requested != nullptr ? *configs.requested : *configs.primary; source_address_token_error = ParseSourceAddressToken(*config.source_address_token_boxer, srct, - &info->source_address_tokens); + info->source_address_tokens); if (source_address_token_error == HANDSHAKE_OK) { source_address_token_error = ValidateSourceAddressTokens( @@ -1294,17 +1288,6 @@ void QuicCryptoServerConfig::EvaluateClientHello( // No valid source address token. } - if (!client_hello_state->postpone_cert_validate_for_server) { - QuicReferenceCountedPointer chain = - proof_source_->GetCertChain(server_address, client_address, - std::string(info->sni)); - if (!chain) { - info->reject_reasons.push_back(SERVER_CONFIG_UNKNOWN_CONFIG_FAILURE); - } else if (!ValidateExpectedLeafCertificate(client_hello, chain->certs)) { - info->reject_reasons.push_back(INVALID_EXPECTED_LEAF_CERTIFICATE); - } - } - if (info->client_nonce.size() != kNonceSize) { info->reject_reasons.push_back(CLIENT_NONCE_INVALID_FAILURE); // Invalid client nonce. @@ -1642,15 +1625,6 @@ QuicCryptoServerConfig::ParseConfigProtobuf( static_assert(sizeof(config->orbit) == kOrbitSize, "incorrect orbit size"); memcpy(config->orbit, orbit.data(), sizeof(config->orbit)); - if ((kexs_tags.size() != static_cast(protobuf.key_size())) && - (!GetQuicRestartFlag(dont_fetch_quic_private_keys_from_leto) && - protobuf.key_size() == 0)) { - QUIC_LOG(WARNING) << "Server config has " << kexs_tags.size() - << " key exchange methods configured, but " - << protobuf.key_size() << " private keys"; - return nullptr; - } - QuicTagVector proof_demand_tags; if (msg->GetTaglist(kPDMD, &proof_demand_tags) == QUIC_NO_ERROR) { for (QuicTag tag : proof_demand_tags) { @@ -1775,38 +1749,20 @@ ProofSource* QuicCryptoServerConfig::proof_source() const { return proof_source_.get(); } -ServerProofVerifier* QuicCryptoServerConfig::proof_verifier() const { - return proof_verifier_.get(); -} - -void QuicCryptoServerConfig::set_proof_verifier( - std::unique_ptr proof_verifier) { - proof_verifier_ = std::move(proof_verifier); -} - -ClientCertMode QuicCryptoServerConfig::client_cert_mode() const { - return client_cert_mode_; -} - -void QuicCryptoServerConfig::set_client_cert_mode(ClientCertMode mode) { - client_cert_mode_ = mode; -} - SSL_CTX* QuicCryptoServerConfig::ssl_ctx() const { return ssl_ctx_.get(); } HandshakeFailureReason QuicCryptoServerConfig::ParseSourceAddressToken( - const CryptoSecretBoxer& crypto_secret_boxer, - absl::string_view token, - SourceAddressTokens* tokens) const { + const CryptoSecretBoxer& crypto_secret_boxer, absl::string_view token, + SourceAddressTokens& tokens) const { std::string storage; absl::string_view plaintext; if (!crypto_secret_boxer.Unbox(token, &storage, &plaintext)) { return SOURCE_ADDRESS_TOKEN_DECRYPTION_FAILURE; } - if (!tokens->ParseFromArray(plaintext.data(), plaintext.size())) { + if (!tokens.ParseFromArray(plaintext.data(), plaintext.size())) { // Some clients might still be using the old source token format so // attempt to parse that format. // TODO(rch): remove this code once the new format is ubiquitous. @@ -1814,7 +1770,7 @@ HandshakeFailureReason QuicCryptoServerConfig::ParseSourceAddressToken( if (!token.ParseFromArray(plaintext.data(), plaintext.size())) { return SOURCE_ADDRESS_TOKEN_PARSE_FAILURE; } - *tokens->add_tokens() = token; + *tokens.add_tokens() = token; } return HANDSHAKE_OK; @@ -1830,7 +1786,8 @@ HandshakeFailureReason QuicCryptoServerConfig::ValidateSourceAddressTokens( for (const SourceAddressToken& token : source_address_tokens.tokens()) { reason = ValidateSingleSourceAddressToken(token, ip, now); if (reason == HANDSHAKE_OK) { - if (token.has_cached_network_parameters()) { + if (cached_network_params != nullptr && + token.has_cached_network_parameters()) { *cached_network_params = token.cached_network_parameters(); } break; diff --git a/gquiche/quic/core/crypto/quic_crypto_server_config.h b/gquiche/quic/core/crypto/quic_crypto_server_config.h index 7def2c04..b68094b4 100644 --- a/gquiche/quic/core/crypto/quic_crypto_server_config.h +++ b/gquiche/quic/core/crypto/quic_crypto_server_config.h @@ -20,10 +20,8 @@ #include "gquiche/quic/core/crypto/crypto_secret_boxer.h" #include "gquiche/quic/core/crypto/key_exchange.h" #include "gquiche/quic/core/crypto/proof_source.h" -#include "gquiche/quic/core/crypto/proof_verifier.h" #include "gquiche/quic/core/crypto/quic_compressed_certs_cache.h" #include "gquiche/quic/core/crypto/quic_crypto_proof.h" -#include "gquiche/quic/core/crypto/server_proof_verifier.h" #include "gquiche/quic/core/proto/cached_network_parameters_proto.h" #include "gquiche/quic/core/proto/source_address_token_proto.h" #include "gquiche/quic/core/quic_time.h" @@ -98,9 +96,6 @@ class QUIC_EXPORT_PRIVATE ValidateClientHelloResultCallback { // Populated if the CHLO STK contained a CachedNetworkParameters proto. CachedNetworkParameters cached_network_params; - const bool postpone_cert_validate_for_server = - GetQuicReloadableFlag(quic_crypto_postpone_cert_validate_for_server); - protected: ~Result() override; }; @@ -436,9 +431,8 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerConfig { // Returns HANDSHAKE_OK if |token| could be parsed, or the reason for the // failure. HandshakeFailureReason ParseSourceAddressToken( - const CryptoSecretBoxer& crypto_secret_boxer, - absl::string_view token, - SourceAddressTokens* tokens) const; + const CryptoSecretBoxer& crypto_secret_boxer, absl::string_view token, + SourceAddressTokens& tokens) const; // ValidateSourceAddressTokens returns HANDSHAKE_OK if the source address // tokens in |tokens| contain a valid and timely token for the IP address @@ -458,11 +452,6 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerConfig { } ProofSource* proof_source() const; - ServerProofVerifier* proof_verifier() const; - void set_proof_verifier(std::unique_ptr proof_verifier); - - ClientCertMode client_cert_mode() const; - void set_client_cert_mode(ClientCertMode client_cert_mode); SSL_CTX* ssl_ctx() const; @@ -510,9 +499,6 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerConfig { // one-to-one, with the tags in |kexs| from the parent class. std::vector> key_exchanges; - // tag_value_map contains the raw key/value pairs for the config. - QuicTagValueMap tag_value_map; - // channel_id_enabled is true if the config in |serialized| specifies that // ChannelIDs are supported. bool channel_id_enabled; @@ -926,8 +912,6 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerConfig { // proof_source_ contains an object that can provide certificate chains and // signatures. std::unique_ptr proof_source_; - std::unique_ptr proof_verifier_; - ClientCertMode client_cert_mode_; // key_exchange_source_ contains an object that can provide key exchange // objects. diff --git a/gquiche/quic/core/crypto/quic_crypto_server_config_test.cc b/gquiche/quic/core/crypto/quic_crypto_server_config_test.cc index c087ef51..89b6f56e 100644 --- a/gquiche/quic/core/crypto/quic_crypto_server_config_test.cc +++ b/gquiche/quic/core/crypto/quic_crypto_server_config_test.cc @@ -25,7 +25,6 @@ #include "gquiche/quic/test_tools/mock_clock.h" #include "gquiche/quic/test_tools/quic_crypto_server_config_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { diff --git a/gquiche/quic/core/crypto/quic_hkdf_test.cc b/gquiche/quic/core/crypto/quic_hkdf_test.cc index 04cec094..349557d6 100644 --- a/gquiche/quic/core/crypto/quic_hkdf_test.cc +++ b/gquiche/quic/core/crypto/quic_hkdf_test.cc @@ -9,7 +9,6 @@ #include "absl/base/macros.h" #include "absl/strings/escaping.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { diff --git a/gquiche/quic/core/crypto/quic_random.cc b/gquiche/quic/core/crypto/quic_random.cc index becb8385..8ee07095 100644 --- a/gquiche/quic/core/crypto/quic_random.cc +++ b/gquiche/quic/core/crypto/quic_random.cc @@ -19,17 +19,22 @@ namespace { // xoshiro256++ 1.0 based on code in the public domain from // . +inline uint64_t Xoshiro256InitializeRngStateMember() { + uint64_t result; + RAND_bytes(reinterpret_cast(&result), sizeof(result)); + return result; +} + inline uint64_t Xoshiro256PlusPlusRotLeft(uint64_t x, int k) { return (x << k) | (x >> (64 - k)); } uint64_t Xoshiro256PlusPlus() { - static thread_local uint64_t rng_state[4]; - static thread_local bool rng_state_initialized = false; - if (QUIC_PREDICT_FALSE(!rng_state_initialized)) { - RAND_bytes(reinterpret_cast(&rng_state), sizeof(rng_state)); - rng_state_initialized = true; - } + static thread_local uint64_t rng_state[4] = { + Xoshiro256InitializeRngStateMember(), + Xoshiro256InitializeRngStateMember(), + Xoshiro256InitializeRngStateMember(), + Xoshiro256InitializeRngStateMember()}; const uint64_t result = Xoshiro256PlusPlusRotLeft(rng_state[0] + rng_state[3], 23) + rng_state[0]; const uint64_t t = rng_state[1] << 17; diff --git a/gquiche/quic/core/crypto/tls_client_connection.cc b/gquiche/quic/core/crypto/tls_client_connection.cc index 344c1a85..f9936bb8 100644 --- a/gquiche/quic/core/crypto/tls_client_connection.cc +++ b/gquiche/quic/core/crypto/tls_client_connection.cc @@ -6,16 +6,20 @@ namespace quic { -TlsClientConnection::TlsClientConnection(SSL_CTX* ssl_ctx, Delegate* delegate) - : TlsConnection(ssl_ctx, delegate->ConnectionDelegate()), +TlsClientConnection::TlsClientConnection(SSL_CTX* ssl_ctx, + Delegate* delegate, + QuicSSLConfig ssl_config) + : TlsConnection(ssl_ctx, + delegate->ConnectionDelegate(), + std::move(ssl_config)), delegate_(delegate) {} // static bssl::UniquePtr TlsClientConnection::CreateSslCtx( bool enable_early_data) { - bssl::UniquePtr ssl_ctx = - TlsConnection::CreateSslCtx(SSL_VERIFY_PEER); + bssl::UniquePtr ssl_ctx = TlsConnection::CreateSslCtx(); // Configure certificate verification. + SSL_CTX_set_custom_verify(ssl_ctx.get(), SSL_VERIFY_PEER, &VerifyCallback); int reverify_on_resume_enabled = 1; SSL_CTX_set_reverify_on_resume(ssl_ctx.get(), reverify_on_resume_enabled); @@ -24,10 +28,18 @@ bssl::UniquePtr TlsClientConnection::CreateSslCtx( ssl_ctx.get(), SSL_SESS_CACHE_CLIENT | SSL_SESS_CACHE_NO_INTERNAL); SSL_CTX_sess_set_new_cb(ssl_ctx.get(), NewSessionCallback); + // TODO(wub): Always enable early data on the SSL_CTX, but allow it to be + // overridden on the SSL object, via QuicSSLConfig. SSL_CTX_set_early_data_enabled(ssl_ctx.get(), enable_early_data); return ssl_ctx; } +void TlsClientConnection::SetCertChain( + const std::vector& cert_chain, EVP_PKEY* privkey) { + SSL_set_chain_and_key(ssl(), cert_chain.data(), cert_chain.size(), privkey, + /*privkey_method=*/nullptr); +} + // static int TlsClientConnection::NewSessionCallback(SSL* ssl, SSL_SESSION* session) { static_cast(ConnectionFromSsl(ssl)) diff --git a/gquiche/quic/core/crypto/tls_client_connection.h b/gquiche/quic/core/crypto/tls_client_connection.h index 3d2d2916..11fcb649 100644 --- a/gquiche/quic/core/crypto/tls_client_connection.h +++ b/gquiche/quic/core/crypto/tls_client_connection.h @@ -30,11 +30,18 @@ class QUIC_EXPORT_PRIVATE TlsClientConnection : public TlsConnection { friend class TlsClientConnection; }; - TlsClientConnection(SSL_CTX* ssl_ctx, Delegate* delegate); + TlsClientConnection(SSL_CTX* ssl_ctx, + Delegate* delegate, + QuicSSLConfig ssl_config); // Creates and configures an SSL_CTX that is appropriate for clients to use. static bssl::UniquePtr CreateSslCtx(bool enable_early_data); + // Set the client cert and private key to be used on this connection, if + // requested by the server. + void SetCertChain(const std::vector& cert_chain, + EVP_PKEY* privkey); + private: // Registered as the callback for SSL_CTX_sess_set_new_cb, which calls // Delegate::InsertSession. diff --git a/gquiche/quic/core/crypto/tls_connection.cc b/gquiche/quic/core/crypto/tls_connection.cc index de748317..767af71f 100644 --- a/gquiche/quic/core/crypto/tls_connection.cc +++ b/gquiche/quic/core/crypto/tls_connection.cc @@ -5,6 +5,7 @@ #include "gquiche/quic/core/crypto/tls_connection.h" #include "absl/strings/string_view.h" +#include "openssl/ssl.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" namespace quic { @@ -88,22 +89,44 @@ enum ssl_encryption_level_t TlsConnection::BoringEncryptionLevel( } TlsConnection::TlsConnection(SSL_CTX* ssl_ctx, - TlsConnection::Delegate* delegate) - : delegate_(delegate), ssl_(SSL_new(ssl_ctx)) { + TlsConnection::Delegate* delegate, + QuicSSLConfig ssl_config) + : delegate_(delegate), + ssl_(SSL_new(ssl_ctx)), + ssl_config_(std::move(ssl_config)) { SSL_set_ex_data( ssl(), SslIndexSingleton::GetInstance()->ssl_ex_data_index_connection(), this); + if (ssl_config_.early_data_enabled.has_value()) { + const int early_data_enabled = *ssl_config_.early_data_enabled ? 1 : 0; + SSL_set_early_data_enabled(ssl(), early_data_enabled); + } + if (ssl_config_.signing_algorithm_prefs.has_value()) { + SSL_set_signing_algorithm_prefs( + ssl(), ssl_config_.signing_algorithm_prefs->data(), + ssl_config_.signing_algorithm_prefs->size()); + } + if (ssl_config.disable_ticket_support.has_value()) { + if (*ssl_config.disable_ticket_support) { + SSL_set_options(ssl(), SSL_OP_NO_TICKET); + } + } } + +void TlsConnection::EnableInfoCallback() { + SSL_set_info_callback( + ssl(), +[](const SSL* ssl, int type, int value) { + ConnectionFromSsl(ssl)->delegate_->InfoCallback(type, value); + }); +} + // static -bssl::UniquePtr TlsConnection::CreateSslCtx(int cert_verify_mode) { +bssl::UniquePtr TlsConnection::CreateSslCtx() { CRYPTO_library_init(); bssl::UniquePtr ssl_ctx(SSL_CTX_new(TLS_with_buffers_method())); SSL_CTX_set_min_proto_version(ssl_ctx.get(), TLS1_3_VERSION); SSL_CTX_set_max_proto_version(ssl_ctx.get(), TLS1_3_VERSION); SSL_CTX_set_quic_method(ssl_ctx.get(), &kSslQuicMethod); - if (cert_verify_mode != SSL_VERIFY_NONE) { - SSL_CTX_set_custom_verify(ssl_ctx.get(), cert_verify_mode, &VerifyCallback); - } return ssl_ctx; } diff --git a/gquiche/quic/core/crypto/tls_connection.h b/gquiche/quic/core/crypto/tls_connection.h index a42c8327..ca9a3f32 100644 --- a/gquiche/quic/core/crypto/tls_connection.h +++ b/gquiche/quic/core/crypto/tls_connection.h @@ -75,12 +75,21 @@ class QUIC_EXPORT_PRIVATE TlsConnection { // level |level|. virtual void SendAlert(EncryptionLevel level, uint8_t desc) = 0; + // Informational callback from BoringSSL. This callback is disabled by + // default, but can be enabled by TlsConnection::EnableInfoCallback. + // + // See |SSL_CTX_set_info_callback| for the meaning of |type| and |value|. + virtual void InfoCallback(int type, int value) = 0; + friend class TlsConnection; }; TlsConnection(const TlsConnection&) = delete; TlsConnection& operator=(const TlsConnection&) = delete; + // Configure the SSL such that delegate_->InfoCallback will be called. + void EnableInfoCallback(); + // Functions to convert between BoringSSL's enum ssl_encryption_level_t and // QUIC's EncryptionLevel. static EncryptionLevel QuicEncryptionLevel(enum ssl_encryption_level_t level); @@ -89,20 +98,17 @@ class QUIC_EXPORT_PRIVATE TlsConnection { SSL* ssl() const { return ssl_.get(); } + const QuicSSLConfig& ssl_config() const { return ssl_config_; } + protected: - // TlsConnection does not take ownership of any of its arguments; they must + // TlsConnection does not take ownership of |ssl_ctx| or |delegate|; they must // outlive the TlsConnection object. - TlsConnection(SSL_CTX* ssl_ctx, Delegate* delegate); + TlsConnection(SSL_CTX* ssl_ctx, Delegate* delegate, QuicSSLConfig ssl_config); // Creates an SSL_CTX and configures it with the options that are appropriate // for both client and server. The caller is responsible for ownership of the // newly created struct. - // - // The provided |cert_verify_mode| is passed in as the |mode| argument for - // |SSL_CTX_set_verify|. See - // https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#SSL_VERIFY_NONE - // for a description of possible values. - static bssl::UniquePtr CreateSslCtx(int cert_verify_mode); + static bssl::UniquePtr CreateSslCtx(); // From a given SSL* |ssl|, returns a pointer to the TlsConnection that it // belongs to. This helper method allows the callbacks set in BoringSSL to be @@ -110,11 +116,13 @@ class QUIC_EXPORT_PRIVATE TlsConnection { // callback. static TlsConnection* ConnectionFromSsl(const SSL* ssl); - private: - // Registered as the callback for SSL_CTX_set_custom_verify. The + // Registered as the callback for SSL(_CTX)_set_custom_verify. The // implementation is delegated to Delegate::VerifyCert. static enum ssl_verify_result_t VerifyCallback(SSL* ssl, uint8_t* out_alert); + QuicSSLConfig& mutable_ssl_config() { return ssl_config_; } + + private: // TlsConnection implements SSL_QUIC_METHOD, which provides the interface // between BoringSSL's TLS stack and a QUIC implementation. static const SSL_QUIC_METHOD kSslQuicMethod; @@ -141,6 +149,7 @@ class QUIC_EXPORT_PRIVATE TlsConnection { Delegate* delegate_; bssl::UniquePtr ssl_; + QuicSSLConfig ssl_config_; }; } // namespace quic diff --git a/gquiche/quic/core/crypto/tls_server_connection.cc b/gquiche/quic/core/crypto/tls_server_connection.cc index a1f7b100..4749130f 100644 --- a/gquiche/quic/core/crypto/tls_server_connection.cc +++ b/gquiche/quic/core/crypto/tls_server_connection.cc @@ -7,20 +7,33 @@ #include "absl/strings/string_view.h" #include "openssl/ssl.h" #include "gquiche/quic/core/crypto/proof_source.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" namespace quic { -TlsServerConnection::TlsServerConnection(SSL_CTX* ssl_ctx, Delegate* delegate) - : TlsConnection(ssl_ctx, delegate->ConnectionDelegate()), - delegate_(delegate) {} +TlsServerConnection::TlsServerConnection(SSL_CTX* ssl_ctx, Delegate* delegate, + QuicSSLConfig ssl_config) + : TlsConnection(ssl_ctx, delegate->ConnectionDelegate(), + std::move(ssl_config)), + delegate_(delegate) { + // By default, cert verify callback is not installed on ssl(), so only need to + // UpdateCertVerifyCallback() if client_cert_mode is not kNone. + if (TlsConnection::ssl_config().client_cert_mode != ClientCertMode::kNone) { + UpdateCertVerifyCallback(); + } +} // static bssl::UniquePtr TlsServerConnection::CreateSslCtx( ProofSource* proof_source) { - bssl::UniquePtr ssl_ctx = - TlsConnection::CreateSslCtx(SSL_VERIFY_NONE); + bssl::UniquePtr ssl_ctx = TlsConnection::CreateSslCtx(); + + // Server does not request/verify client certs by default. Individual server + // connections may call SSL_set_custom_verify on their SSL object to request + // client certs. + SSL_CTX_set_tlsext_servername_callback(ssl_ctx.get(), &TlsExtServernameCallback); SSL_CTX_set_alpn_select_cb(ssl_ctx.get(), &SelectAlpnCallback, nullptr); @@ -30,16 +43,12 @@ bssl::UniquePtr TlsServerConnection::CreateSslCtx( QUIC_CODE_COUNT(quic_session_tickets_enabled); SSL_CTX_set_ticket_aead_method(ssl_ctx.get(), &TlsServerConnection::kSessionTicketMethod); - } else if (!GetQuicRestartFlag(quic_session_tickets_always_enabled)) { - QUIC_CODE_COUNT(quic_session_tickets_disabled_by_flag); - SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_NO_TICKET); } else { QUIC_CODE_COUNT(quic_session_tickets_disabled); } - if (proof_source->GetTicketCrypter() || - GetQuicRestartFlag(quic_session_tickets_always_enabled)) { - SSL_CTX_set_early_data_enabled(ssl_ctx.get(), 1); - } + + SSL_CTX_set_early_data_enabled(ssl_ctx.get(), 1); + SSL_CTX_set_select_certificate_cb( ssl_ctx.get(), &TlsServerConnection::EarlySelectCertCallback); SSL_CTX_set_options(ssl_ctx.get(), SSL_OP_CIPHER_SERVER_PREFERENCE); @@ -52,6 +61,31 @@ void TlsServerConnection::SetCertChain( &TlsServerConnection::kPrivateKeyMethod); } +void TlsServerConnection::SetClientCertMode(ClientCertMode client_cert_mode) { + if (ssl_config().client_cert_mode == client_cert_mode) { + return; + } + + mutable_ssl_config().client_cert_mode = client_cert_mode; + UpdateCertVerifyCallback(); +} + +void TlsServerConnection::UpdateCertVerifyCallback() { + const ClientCertMode client_cert_mode = ssl_config().client_cert_mode; + if (client_cert_mode == ClientCertMode::kNone) { + SSL_set_custom_verify(ssl(), SSL_VERIFY_NONE, nullptr); + return; + } + + int mode = SSL_VERIFY_PEER; + if (client_cert_mode == ClientCertMode::kRequire) { + mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + } else { + QUICHE_DCHECK_EQ(client_cert_mode, ClientCertMode::kRequest); + } + SSL_set_custom_verify(ssl(), mode, &VerifyCallback); +} + const SSL_PRIVATE_KEY_METHOD TlsServerConnection::kPrivateKeyMethod{ &TlsServerConnection::PrivateKeySign, nullptr, // decrypt diff --git a/gquiche/quic/core/crypto/tls_server_connection.h b/gquiche/quic/core/crypto/tls_server_connection.h index 5c942165..84e8fec2 100644 --- a/gquiche/quic/core/crypto/tls_server_connection.h +++ b/gquiche/quic/core/crypto/tls_server_connection.h @@ -120,13 +120,20 @@ class QUIC_EXPORT_PRIVATE TlsServerConnection : public TlsConnection { friend class TlsServerConnection; }; - TlsServerConnection(SSL_CTX* ssl_ctx, Delegate* delegate); + TlsServerConnection(SSL_CTX* ssl_ctx, + Delegate* delegate, + QuicSSLConfig ssl_config); // Creates and configures an SSL_CTX that is appropriate for servers to use. static bssl::UniquePtr CreateSslCtx(ProofSource* proof_source); void SetCertChain(const std::vector& cert_chain); + // Set the client cert mode to be used on this connection. This should be + // called right after cert selection at the latest, otherwise it is too late + // to has an effect. + void SetClientCertMode(ClientCertMode client_cert_mode); + private: // Specialization of TlsConnection::ConnectionFromSsl. static TlsServerConnection* ConnectionFromSsl(SSL* ssl); @@ -181,6 +188,10 @@ class QUIC_EXPORT_PRIVATE TlsServerConnection : public TlsConnection { const uint8_t* in, size_t in_len); + // Install custom verify callback on ssl() if |ssl_config().client_cert_mode| + // is not ClientCertMode::kNone. Uninstall otherwise. + void UpdateCertVerifyCallback(); + Delegate* delegate_; }; diff --git a/gquiche/quic/core/crypto/transport_parameters.cc b/gquiche/quic/core/crypto/transport_parameters.cc index d75b948a..b5e28b3e 100644 --- a/gquiche/quic/core/crypto/transport_parameters.cc +++ b/gquiche/quic/core/crypto/transport_parameters.cc @@ -22,6 +22,7 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" namespace quic { @@ -61,7 +62,8 @@ enum TransportParameters::TransportParameterId : uint64_t { kGoogleQuicVersion = 0x4752, // Used to transmit version and supported_versions. - kMinAckDelay = 0xDE1A, // draft-iyengar-quic-delayed-ack. + kMinAckDelay = 0xDE1A, // draft-iyengar-quic-delayed-ack. + kVersionInformation = 0xFF73DB, // draft-ietf-quic-version-negotiation. }; namespace { @@ -128,6 +130,8 @@ std::string TransportParameterIdToString( return "google-version"; case TransportParameters::kMinAckDelay: return "min_ack_delay_us"; + case TransportParameters::kVersionInformation: + return "version_information"; } return absl::StrCat("Unknown(", param_id, ")"); } @@ -155,11 +159,15 @@ bool TransportParameterIdIsKnown( case TransportParameters::kMaxDatagramFrameSize: case TransportParameters::kInitialRoundTripTime: case TransportParameters::kGoogleConnectionOptions: - case TransportParameters::kGoogleUserAgentId: - case TransportParameters::kGoogleKeyUpdateNotYetSupported: case TransportParameters::kGoogleQuicVersion: case TransportParameters::kMinAckDelay: return true; + case TransportParameters::kVersionInformation: + return GetQuicReloadableFlag(quic_version_information); + case TransportParameters::kGoogleUserAgentId: + return !GetQuicReloadableFlag(quic_ignore_user_agent_transport_parameter); + case TransportParameters::kGoogleKeyUpdateNotYetSupported: + return !GetQuicReloadableFlag(quic_ignore_key_update_not_yet_supported); } return false; } @@ -306,6 +314,70 @@ std::string TransportParameters::PreferredAddress::ToString() const { "]"; } +TransportParameters::LegacyVersionInformation::LegacyVersionInformation() + : version(0) {} + +bool TransportParameters::LegacyVersionInformation::operator==( + const LegacyVersionInformation& rhs) const { + return version == rhs.version && supported_versions == rhs.supported_versions; +} + +bool TransportParameters::LegacyVersionInformation::operator!=( + const LegacyVersionInformation& rhs) const { + return !(*this == rhs); +} + +std::string TransportParameters::LegacyVersionInformation::ToString() const { + std::string rv = + absl::StrCat("legacy[version ", QuicVersionLabelToString(version)); + if (!supported_versions.empty()) { + absl::StrAppend(&rv, + " supported_versions " + + QuicVersionLabelVectorToString(supported_versions)); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::ostream& operator<<(std::ostream& os, + const TransportParameters::LegacyVersionInformation& + legacy_version_information) { + os << legacy_version_information.ToString(); + return os; +} + +TransportParameters::VersionInformation::VersionInformation() + : chosen_version(0) {} + +bool TransportParameters::VersionInformation::operator==( + const VersionInformation& rhs) const { + return chosen_version == rhs.chosen_version && + other_versions == rhs.other_versions; +} + +bool TransportParameters::VersionInformation::operator!=( + const VersionInformation& rhs) const { + return !(*this == rhs); +} + +std::string TransportParameters::VersionInformation::ToString() const { + std::string rv = absl::StrCat("[chosen_version ", + QuicVersionLabelToString(chosen_version)); + if (!other_versions.empty()) { + absl::StrAppend(&rv, " other_versions " + + QuicVersionLabelVectorToString(other_versions)); + } + absl::StrAppend(&rv, "]"); + return rv; +} + +std::ostream& operator<<( + std::ostream& os, + const TransportParameters::VersionInformation& version_information) { + os << version_information.ToString(); + return os; +} + std::ostream& operator<<(std::ostream& os, const TransportParameters& params) { os << params.ToString(); return os; @@ -318,12 +390,11 @@ std::string TransportParameters::ToString() const { } else { rv += "Client"; } - if (version != 0) { - rv += " version " + QuicVersionLabelToString(version); + if (legacy_version_information.has_value()) { + rv += " " + legacy_version_information.value().ToString(); } - if (!supported_versions.empty()) { - rv += " supported_versions " + - QuicVersionLabelVectorToString(supported_versions); + if (version_information.has_value()) { + rv += " " + version_information.value().ToString(); } if (original_destination_connection_id.has_value()) { rv += " " + TransportParameterIdToString(kOriginalDestinationConnectionId) + @@ -400,12 +471,9 @@ std::string TransportParameters::ToString() const { } TransportParameters::TransportParameters() - : version(0), - max_idle_timeout_ms(kMaxIdleTimeout), - max_udp_payload_size(kMaxPacketSize, - kDefaultMaxPacketSizeTransportParam, - kMinMaxPacketSizeTransportParam, - kVarInt62MaxValue), + : max_idle_timeout_ms(kMaxIdleTimeout), + max_udp_payload_size(kMaxPacketSize, kDefaultMaxPacketSizeTransportParam, + kMinMaxPacketSizeTransportParam, kVarInt62MaxValue), initial_max_data(kInitialMaxData), initial_max_stream_data_bidi_local(kInitialMaxStreamDataBidiLocal), initial_max_stream_data_bidi_remote(kInitialMaxStreamDataBidiRemote), @@ -413,16 +481,11 @@ TransportParameters::TransportParameters() initial_max_streams_bidi(kInitialMaxStreamsBidi), initial_max_streams_uni(kInitialMaxStreamsUni), ack_delay_exponent(kAckDelayExponent, - kDefaultAckDelayExponentTransportParam, - 0, + kDefaultAckDelayExponentTransportParam, 0, kMaxAckDelayExponentTransportParam), - max_ack_delay(kMaxAckDelay, - kDefaultMaxAckDelayTransportParam, - 0, + max_ack_delay(kMaxAckDelay, kDefaultMaxAckDelayTransportParam, 0, kMaxMaxAckDelayTransportParam), - min_ack_delay_us(kMinAckDelay, - 0, - 0, + min_ack_delay_us(kMinAckDelay, 0, 0, kMaxMaxAckDelayTransportParam * kNumMicrosPerMilli), disable_active_migration(false), active_connection_id_limit(kActiveConnectionIdLimit, @@ -440,8 +503,8 @@ TransportParameters::TransportParameters() TransportParameters::TransportParameters(const TransportParameters& other) : perspective(other.perspective), - version(other.version), - supported_versions(other.supported_versions), + legacy_version_information(other.legacy_version_information), + version_information(other.version_information), original_destination_connection_id( other.original_destination_connection_id), max_idle_timeout_ms(other.max_idle_timeout_ms), @@ -475,8 +538,9 @@ TransportParameters::TransportParameters(const TransportParameters& other) } bool TransportParameters::operator==(const TransportParameters& rhs) const { - if (!(perspective == rhs.perspective && version == rhs.version && - supported_versions == rhs.supported_versions && + if (!(perspective == rhs.perspective && + legacy_version_information == rhs.legacy_version_information && + version_information == rhs.version_information && original_destination_connection_id == rhs.original_destination_connection_id && max_idle_timeout_ms.value() == rhs.max_idle_timeout_ms.value() && @@ -586,6 +650,29 @@ bool TransportParameters::AreValid(std::string* error_details) const { *error_details = "Server cannot send user agent ID"; return false; } + if (version_information.has_value()) { + const QuicVersionLabel& chosen_version = + version_information.value().chosen_version; + const QuicVersionLabelVector& other_versions = + version_information.value().other_versions; + if (chosen_version == 0) { + *error_details = "Invalid chosen version"; + return false; + } + if (perspective == Perspective::IS_CLIENT && + std::find(other_versions.begin(), other_versions.end(), + chosen_version) == other_versions.end()) { + // When sent by the client, chosen_version needs to be present in + // other_versions because other_versions lists the compatible versions and + // the chosen version is part of that list. When sent by the server, + // other_version contains the list of fully-deployed versions which is + // generally equal to the list of supported versions but can slightly + // differ during removal of versions across a server fleet. See + // draft-ietf-quic-version-negotiation for details. + *error_details = "Client chosen version not in other versions"; + return false; + } + } const bool ok = max_idle_timeout_ms.IsValid() && max_udp_payload_size.IsValid() && initial_max_data.IsValid() && @@ -609,15 +696,27 @@ bool SerializeTransportParameters(ParsedQuicVersion /*version*/, std::vector* out) { std::string error_details; if (!in.AreValid(&error_details)) { - QUIC_BUG(quic_bug_10743_5) + QUIC_BUG(invalid transport parameters) << "Not serializing invalid transport parameters: " << error_details; return false; } - if (in.version == 0 || (in.perspective == Perspective::IS_SERVER && - in.supported_versions.empty())) { - QUIC_BUG(quic_bug_10743_6) << "Refusing to serialize without versions"; + if (!in.legacy_version_information.has_value() || + in.legacy_version_information.value().version == 0 || + (in.perspective == Perspective::IS_SERVER && + in.legacy_version_information.value().supported_versions.empty())) { + QUIC_BUG(missing versions) << "Refusing to serialize without versions"; return false; } + TransportParameters::ParameterMap custom_parameters = in.custom_parameters; + for (const auto& kv : custom_parameters) { + if (kv.first % 31 == 27) { + // See the "Reserved Transport Parameters" section of RFC 9000. + QUIC_BUG(custom_parameters with GREASE) + << "Serializing custom_parameters with GREASE ID " << kv.first + << " is not allowed"; + return false; + } + } // Maximum length of the GREASE transport parameter (see below). static constexpr size_t kMaxGreaseLength = 16; @@ -637,8 +736,6 @@ bool SerializeTransportParameters(ParsedQuicVersion /*version*/, kTypeAndValueLength + 4 /*IPv4 address */ + 2 /* IPv4 port */ + 16 /* IPv6 address */ + 1 /* Connection ID length */ + 255 /* maximum connection ID length */ + 16 /* stateless reset token */; - static constexpr size_t kGreaseParameterLength = - kTypeAndValueLength + kMaxGreaseLength; static constexpr size_t kKnownTransportParamLength = kConnectionIdParameterLength + // original_destination_connection_id kIntegerParameterLength + // max_idle_timeout @@ -663,8 +760,35 @@ bool SerializeTransportParameters(ParsedQuicVersion /*version*/, kTypeAndValueLength + // google_connection_options kTypeAndValueLength + // user_agent_id kTypeAndValueLength + // key_update_not_yet_supported - kTypeAndValueLength + // google-version - kGreaseParameterLength; // GREASE + kTypeAndValueLength; // google-version + + std::vector parameter_ids = { + TransportParameters::kOriginalDestinationConnectionId, + TransportParameters::kMaxIdleTimeout, + TransportParameters::kStatelessResetToken, + TransportParameters::kMaxPacketSize, + TransportParameters::kInitialMaxData, + TransportParameters::kInitialMaxStreamDataBidiLocal, + TransportParameters::kInitialMaxStreamDataBidiRemote, + TransportParameters::kInitialMaxStreamDataUni, + TransportParameters::kInitialMaxStreamsBidi, + TransportParameters::kInitialMaxStreamsUni, + TransportParameters::kAckDelayExponent, + TransportParameters::kMaxAckDelay, + TransportParameters::kMinAckDelay, + TransportParameters::kActiveConnectionIdLimit, + TransportParameters::kMaxDatagramFrameSize, + TransportParameters::kInitialRoundTripTime, + TransportParameters::kDisableActiveMigration, + TransportParameters::kPreferredAddress, + TransportParameters::kInitialSourceConnectionId, + TransportParameters::kRetrySourceConnectionId, + TransportParameters::kGoogleConnectionOptions, + TransportParameters::kGoogleUserAgentId, + TransportParameters::kGoogleKeyUpdateNotYetSupported, + TransportParameters::kGoogleQuicVersion, + TransportParameters::kVersionInformation, + }; size_t max_transport_param_length = kKnownTransportParamLength; // google_connection_options. @@ -677,272 +801,444 @@ bool SerializeTransportParameters(ParsedQuicVersion /*version*/, max_transport_param_length += in.user_agent_id.value().length(); } // Google-specific version extension. - max_transport_param_length += - sizeof(in.version) + 1 /* versions length */ + - in.supported_versions.size() * sizeof(QuicVersionLabel); - // Custom parameters. - for (const auto& kv : in.custom_parameters) { - max_transport_param_length += kTypeAndValueLength + kv.second.length(); - } - - out->resize(max_transport_param_length); - QuicDataWriter writer(out->size(), reinterpret_cast(out->data())); - - // original_destination_connection_id - if (in.original_destination_connection_id.has_value()) { - QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); - QuicConnectionId original_destination_connection_id = - in.original_destination_connection_id.value(); - if (!writer.WriteVarInt62( - TransportParameters::kOriginalDestinationConnectionId) || - !writer.WriteStringPieceVarInt62( - absl::string_view(original_destination_connection_id.data(), - original_destination_connection_id.length()))) { - QUIC_BUG(quic_bug_10743_7) - << "Failed to write original_destination_connection_id " - << original_destination_connection_id << " for " << in; - return false; - } - } - - if (!in.max_idle_timeout_ms.Write(&writer)) { - QUIC_BUG(quic_bug_10743_8) << "Failed to write idle_timeout for " << in; - return false; - } - - // stateless_reset_token - if (!in.stateless_reset_token.empty()) { - QUICHE_DCHECK_EQ(kStatelessResetTokenLength, - in.stateless_reset_token.size()); - QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); - if (!writer.WriteVarInt62(TransportParameters::kStatelessResetToken) || - !writer.WriteStringPieceVarInt62(absl::string_view( - reinterpret_cast(in.stateless_reset_token.data()), - in.stateless_reset_token.size()))) { - QUIC_BUG(quic_bug_10743_9) - << "Failed to write stateless_reset_token of length " - << in.stateless_reset_token.size() << " for " << in; - return false; - } - } - - if (!in.max_udp_payload_size.Write(&writer) || - !in.initial_max_data.Write(&writer) || - !in.initial_max_stream_data_bidi_local.Write(&writer) || - !in.initial_max_stream_data_bidi_remote.Write(&writer) || - !in.initial_max_stream_data_uni.Write(&writer) || - !in.initial_max_streams_bidi.Write(&writer) || - !in.initial_max_streams_uni.Write(&writer) || - !in.ack_delay_exponent.Write(&writer) || - !in.max_ack_delay.Write(&writer) || !in.min_ack_delay_us.Write(&writer) || - !in.active_connection_id_limit.Write(&writer) || - !in.max_datagram_frame_size.Write(&writer) || - !in.initial_round_trip_time_us.Write(&writer)) { - QUIC_BUG(quic_bug_10743_10) << "Failed to write integers for " << in; - return false; - } - - // disable_active_migration - if (in.disable_active_migration) { - if (!writer.WriteVarInt62(TransportParameters::kDisableActiveMigration) || - !writer.WriteVarInt62(/* transport parameter length */ 0)) { - QUIC_BUG(quic_bug_10743_11) - << "Failed to write disable_active_migration for " << in; - return false; - } - } - - // preferred_address - if (in.preferred_address) { - std::string v4_address_bytes = - in.preferred_address->ipv4_socket_address.host().ToPackedString(); - std::string v6_address_bytes = - in.preferred_address->ipv6_socket_address.host().ToPackedString(); - if (v4_address_bytes.length() != 4 || v6_address_bytes.length() != 16 || - in.preferred_address->stateless_reset_token.size() != - kStatelessResetTokenLength) { - QUIC_BUG(quic_bug_10743_12) << "Bad lengths " << *in.preferred_address; - return false; - } - const uint64_t preferred_address_length = - v4_address_bytes.length() + /* IPv4 port */ sizeof(uint16_t) + - v6_address_bytes.length() + /* IPv6 port */ sizeof(uint16_t) + - /* connection ID length byte */ sizeof(uint8_t) + - in.preferred_address->connection_id.length() + - in.preferred_address->stateless_reset_token.size(); - if (!writer.WriteVarInt62(TransportParameters::kPreferredAddress) || - !writer.WriteVarInt62( - /* transport parameter length */ preferred_address_length) || - !writer.WriteStringPiece(v4_address_bytes) || - !writer.WriteUInt16(in.preferred_address->ipv4_socket_address.port()) || - !writer.WriteStringPiece(v6_address_bytes) || - !writer.WriteUInt16(in.preferred_address->ipv6_socket_address.port()) || - !writer.WriteUInt8(in.preferred_address->connection_id.length()) || - !writer.WriteBytes(in.preferred_address->connection_id.data(), - in.preferred_address->connection_id.length()) || - !writer.WriteBytes( - in.preferred_address->stateless_reset_token.data(), - in.preferred_address->stateless_reset_token.size())) { - QUIC_BUG(quic_bug_10743_13) - << "Failed to write preferred_address for " << in; - return false; - } - } - - // initial_source_connection_id - if (in.initial_source_connection_id.has_value()) { - QuicConnectionId initial_source_connection_id = - in.initial_source_connection_id.value(); - if (!writer.WriteVarInt62( - TransportParameters::kInitialSourceConnectionId) || - !writer.WriteStringPieceVarInt62( - absl::string_view(initial_source_connection_id.data(), - initial_source_connection_id.length()))) { - QUIC_BUG(quic_bug_10743_14) - << "Failed to write initial_source_connection_id " - << initial_source_connection_id << " for " << in; - return false; - } - } - - // retry_source_connection_id - if (in.retry_source_connection_id.has_value()) { - QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); - QuicConnectionId retry_source_connection_id = - in.retry_source_connection_id.value(); - if (!writer.WriteVarInt62(TransportParameters::kRetrySourceConnectionId) || - !writer.WriteStringPieceVarInt62( - absl::string_view(retry_source_connection_id.data(), - retry_source_connection_id.length()))) { - QUIC_BUG(quic_bug_10743_15) - << "Failed to write retry_source_connection_id " - << retry_source_connection_id << " for " << in; - return false; - } - } - - // Google-specific connection options. - if (in.google_connection_options.has_value()) { - static_assert(sizeof(in.google_connection_options.value().front()) == 4, - "bad size"); - uint64_t connection_options_length = - in.google_connection_options.value().size() * 4; - if (!writer.WriteVarInt62(TransportParameters::kGoogleConnectionOptions) || - !writer.WriteVarInt62( - /* transport parameter length */ connection_options_length)) { - QUIC_BUG(quic_bug_10743_16) - << "Failed to write google_connection_options of length " - << connection_options_length << " for " << in; - return false; - } - for (const QuicTag& connection_option : - in.google_connection_options.value()) { - if (!writer.WriteTag(connection_option)) { - QUIC_BUG(quic_bug_10743_17) - << "Failed to write google_connection_option " - << QuicTagToString(connection_option) << " for " << in; - return false; - } - } - } - - // Google-specific user agent identifier. - if (in.user_agent_id.has_value()) { - if (!writer.WriteVarInt62(TransportParameters::kGoogleUserAgentId) || - !writer.WriteStringPieceVarInt62(in.user_agent_id.value())) { - QUIC_BUG(quic_bug_10743_18) - << "Failed to write Google user agent ID \"" - << in.user_agent_id.value() << "\" for " << in; - return false; - } + if (in.legacy_version_information.has_value()) { + max_transport_param_length += + sizeof(in.legacy_version_information.value().version) + + 1 /* versions length */ + + in.legacy_version_information.value().supported_versions.size() * + sizeof(QuicVersionLabel); } + // version_information. + if (in.version_information.has_value()) { + max_transport_param_length += + sizeof(in.version_information.value().chosen_version) + + // Add one for the added GREASE version. + (in.version_information.value().other_versions.size() + 1) * + sizeof(QuicVersionLabel); + } + + // Add a random GREASE transport parameter, as defined in the + // "Reserved Transport Parameters" section of RFC 9000. + // This forces receivers to support unexpected input. + QuicRandom* random = QuicRandom::GetInstance(); + // Transport parameter identifiers are 62 bits long so we need to + // ensure that the output of the computation below fits in 62 bits. + uint64_t grease_id64 = random->RandUint64() % ((1ULL << 62) - 31); + // Make sure grease_id % 31 == 27. Note that this is not uniformely + // distributed but is acceptable since no security depends on this + // randomness. + grease_id64 = (grease_id64 / 31) * 31 + 27; + TransportParameters::TransportParameterId grease_id = + static_cast(grease_id64); + const size_t grease_length = random->RandUint64() % kMaxGreaseLength; + QUICHE_DCHECK_GE(kMaxGreaseLength, grease_length); + char grease_contents[kMaxGreaseLength]; + random->RandBytes(grease_contents, grease_length); + custom_parameters[grease_id] = std::string(grease_contents, grease_length); - // Google-specific indicator for key update not yet supported. - if (in.key_update_not_yet_supported) { - if (!writer.WriteVarInt62( - TransportParameters::kGoogleKeyUpdateNotYetSupported) || - !writer.WriteVarInt62(/* transport parameter length */ 0)) { - QUIC_BUG(quic_bug_10743_19) - << "Failed to write key_update_not_yet_supported for " << in; - return false; - } + // Custom parameters. + for (const auto& kv : custom_parameters) { + max_transport_param_length += kTypeAndValueLength + kv.second.length(); + parameter_ids.push_back(kv.first); } - // Google-specific version extension. - static_assert(sizeof(QuicVersionLabel) == sizeof(uint32_t), "bad length"); - uint64_t google_version_length = sizeof(in.version); - if (in.perspective == Perspective::IS_SERVER) { - google_version_length += - /* versions length */ sizeof(uint8_t) + - sizeof(QuicVersionLabel) * in.supported_versions.size(); - } - if (!writer.WriteVarInt62(TransportParameters::kGoogleQuicVersion) || - !writer.WriteVarInt62( - /* transport parameter length */ google_version_length) || - !writer.WriteUInt32(in.version)) { - QUIC_BUG(quic_bug_10743_20) - << "Failed to write Google version extension for " << in; - return false; - } - if (in.perspective == Perspective::IS_SERVER) { - if (!writer.WriteUInt8(sizeof(QuicVersionLabel) * - in.supported_versions.size())) { - QUIC_BUG(quic_bug_10743_21) - << "Failed to write versions length for " << in; - return false; - } - for (QuicVersionLabel version_label : in.supported_versions) { - if (!writer.WriteUInt32(version_label)) { - QUIC_BUG(quic_bug_10743_22) - << "Failed to write supported version for " << in; - return false; - } - } + // Randomize order of sent transport parameters by walking the array + // backwards and swapping each element with a random earlier one. + for (size_t i = parameter_ids.size() - 1; i > 0; i--) { + std::swap(parameter_ids[i], + parameter_ids[random->InsecureRandUint64() % (i + 1)]); } - for (const auto& kv : in.custom_parameters) { - const TransportParameters::TransportParameterId param_id = kv.first; - if (param_id % 31 == 27) { - // See the "Reserved Transport Parameters" section of - // draft-ietf-quic-transport. - QUIC_BUG(quic_bug_10743_23) - << "Serializing custom_parameters with GREASE ID " << param_id - << " is not allowed"; - return false; - } - if (!writer.WriteVarInt62(param_id) || - !writer.WriteStringPieceVarInt62(kv.second)) { - QUIC_BUG(quic_bug_10743_24) - << "Failed to write custom parameter " << param_id; - return false; - } - } + out->resize(max_transport_param_length); + QuicDataWriter writer(out->size(), reinterpret_cast(out->data())); - { - // Add a random GREASE transport parameter, as defined in the - // "Reserved Transport Parameters" section of draft-ietf-quic-transport. - // https://quicwg.org/base-drafts/draft-ietf-quic-transport.html - // This forces receivers to support unexpected input. - QuicRandom* random = QuicRandom::GetInstance(); - // Transport parameter identifiers are 62 bits long so we need to ensure - // that the output of the computation below fits in 62 bits. - uint64_t grease_id64 = random->RandUint64() % ((1ULL << 62) - 31); - // Make sure grease_id % 31 == 27. Note that this is not uniformely - // distributed but is acceptable since no security depends on this - // randomness. - grease_id64 = (grease_id64 / 31) * 31 + 27; - TransportParameters::TransportParameterId grease_id = - static_cast(grease_id64); - const size_t grease_length = random->RandUint64() % kMaxGreaseLength; - QUICHE_DCHECK_GE(kMaxGreaseLength, grease_length); - char grease_contents[kMaxGreaseLength]; - random->RandBytes(grease_contents, grease_length); - if (!writer.WriteVarInt62(grease_id) || - !writer.WriteStringPieceVarInt62( - absl::string_view(grease_contents, grease_length))) { - QUIC_BUG(quic_bug_10743_25) << "Failed to write GREASE parameter " - << TransportParameterIdToString(grease_id); - return false; + for (TransportParameters::TransportParameterId parameter_id : parameter_ids) { + switch (parameter_id) { + // original_destination_connection_id + case TransportParameters::kOriginalDestinationConnectionId: { + if (in.original_destination_connection_id.has_value()) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); + QuicConnectionId original_destination_connection_id = + in.original_destination_connection_id.value(); + if (!writer.WriteVarInt62( + TransportParameters::kOriginalDestinationConnectionId) || + !writer.WriteStringPieceVarInt62(absl::string_view( + original_destination_connection_id.data(), + original_destination_connection_id.length()))) { + QUIC_BUG(Failed to write original_destination_connection_id) + << "Failed to write original_destination_connection_id " + << original_destination_connection_id << " for " << in; + return false; + } + } + } break; + // max_idle_timeout + case TransportParameters::kMaxIdleTimeout: { + if (!in.max_idle_timeout_ms.Write(&writer)) { + QUIC_BUG(Failed to write idle_timeout) + << "Failed to write idle_timeout for " << in; + return false; + } + } break; + // stateless_reset_token + case TransportParameters::kStatelessResetToken: { + if (!in.stateless_reset_token.empty()) { + QUICHE_DCHECK_EQ(kStatelessResetTokenLength, + in.stateless_reset_token.size()); + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); + if (!writer.WriteVarInt62( + TransportParameters::kStatelessResetToken) || + !writer.WriteStringPieceVarInt62( + absl::string_view(reinterpret_cast( + in.stateless_reset_token.data()), + in.stateless_reset_token.size()))) { + QUIC_BUG(Failed to write stateless_reset_token) + << "Failed to write stateless_reset_token of length " + << in.stateless_reset_token.size() << " for " << in; + return false; + } + } + } break; + // max_udp_payload_size + case TransportParameters::kMaxPacketSize: { + if (!in.max_udp_payload_size.Write(&writer)) { + QUIC_BUG(Failed to write max_udp_payload_size) + << "Failed to write max_udp_payload_size for " << in; + return false; + } + } break; + // initial_max_data + case TransportParameters::kInitialMaxData: { + if (!in.initial_max_data.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_data) + << "Failed to write initial_max_data for " << in; + return false; + } + } break; + // initial_max_stream_data_bidi_local + case TransportParameters::kInitialMaxStreamDataBidiLocal: { + if (!in.initial_max_stream_data_bidi_local.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_stream_data_bidi_local) + << "Failed to write initial_max_stream_data_bidi_local for " + << in; + return false; + } + } break; + // initial_max_stream_data_bidi_remote + case TransportParameters::kInitialMaxStreamDataBidiRemote: { + if (!in.initial_max_stream_data_bidi_remote.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_stream_data_bidi_remote) + << "Failed to write initial_max_stream_data_bidi_remote for " + << in; + return false; + } + } break; + // initial_max_stream_data_uni + case TransportParameters::kInitialMaxStreamDataUni: { + if (!in.initial_max_stream_data_uni.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_stream_data_uni) + << "Failed to write initial_max_stream_data_uni for " << in; + return false; + } + } break; + // initial_max_streams_bidi + case TransportParameters::kInitialMaxStreamsBidi: { + if (!in.initial_max_streams_bidi.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_streams_bidi) + << "Failed to write initial_max_streams_bidi for " << in; + return false; + } + } break; + // initial_max_streams_uni + case TransportParameters::kInitialMaxStreamsUni: { + if (!in.initial_max_streams_uni.Write(&writer)) { + QUIC_BUG(Failed to write initial_max_streams_uni) + << "Failed to write initial_max_streams_uni for " << in; + return false; + } + } break; + // ack_delay_exponent + case TransportParameters::kAckDelayExponent: { + if (!in.ack_delay_exponent.Write(&writer)) { + QUIC_BUG(Failed to write ack_delay_exponent) + << "Failed to write ack_delay_exponent for " << in; + return false; + } + } break; + // max_ack_delay + case TransportParameters::kMaxAckDelay: { + if (!in.max_ack_delay.Write(&writer)) { + QUIC_BUG(Failed to write max_ack_delay) + << "Failed to write max_ack_delay for " << in; + return false; + } + } break; + // min_ack_delay_us + case TransportParameters::kMinAckDelay: { + if (!in.min_ack_delay_us.Write(&writer)) { + QUIC_BUG(Failed to write min_ack_delay_us) + << "Failed to write min_ack_delay_us for " << in; + return false; + } + } break; + // active_connection_id_limit + case TransportParameters::kActiveConnectionIdLimit: { + if (!in.active_connection_id_limit.Write(&writer)) { + QUIC_BUG(Failed to write active_connection_id_limit) + << "Failed to write active_connection_id_limit for " << in; + return false; + } + } break; + // max_datagram_frame_size + case TransportParameters::kMaxDatagramFrameSize: { + if (!in.max_datagram_frame_size.Write(&writer)) { + QUIC_BUG(Failed to write max_datagram_frame_size) + << "Failed to write max_datagram_frame_size for " << in; + return false; + } + } break; + // initial_round_trip_time_us + case TransportParameters::kInitialRoundTripTime: { + if (!in.initial_round_trip_time_us.Write(&writer)) { + QUIC_BUG(Failed to write initial_round_trip_time_us) + << "Failed to write initial_round_trip_time_us for " << in; + return false; + } + } break; + // disable_active_migration + case TransportParameters::kDisableActiveMigration: { + if (in.disable_active_migration) { + if (!writer.WriteVarInt62( + TransportParameters::kDisableActiveMigration) || + !writer.WriteVarInt62(/* transport parameter length */ 0)) { + QUIC_BUG(Failed to write disable_active_migration) + << "Failed to write disable_active_migration for " << in; + return false; + } + } + } break; + // preferred_address + case TransportParameters::kPreferredAddress: { + if (in.preferred_address) { + std::string v4_address_bytes = + in.preferred_address->ipv4_socket_address.host().ToPackedString(); + std::string v6_address_bytes = + in.preferred_address->ipv6_socket_address.host().ToPackedString(); + if (v4_address_bytes.length() != 4 || + v6_address_bytes.length() != 16 || + in.preferred_address->stateless_reset_token.size() != + kStatelessResetTokenLength) { + QUIC_BUG(quic_bug_10743_12) + << "Bad lengths " << *in.preferred_address; + return false; + } + const uint64_t preferred_address_length = + v4_address_bytes.length() + /* IPv4 port */ sizeof(uint16_t) + + v6_address_bytes.length() + /* IPv6 port */ sizeof(uint16_t) + + /* connection ID length byte */ sizeof(uint8_t) + + in.preferred_address->connection_id.length() + + in.preferred_address->stateless_reset_token.size(); + if (!writer.WriteVarInt62(TransportParameters::kPreferredAddress) || + !writer.WriteVarInt62( + /* transport parameter length */ preferred_address_length) || + !writer.WriteStringPiece(v4_address_bytes) || + !writer.WriteUInt16( + in.preferred_address->ipv4_socket_address.port()) || + !writer.WriteStringPiece(v6_address_bytes) || + !writer.WriteUInt16( + in.preferred_address->ipv6_socket_address.port()) || + !writer.WriteUInt8( + in.preferred_address->connection_id.length()) || + !writer.WriteBytes( + in.preferred_address->connection_id.data(), + in.preferred_address->connection_id.length()) || + !writer.WriteBytes( + in.preferred_address->stateless_reset_token.data(), + in.preferred_address->stateless_reset_token.size())) { + QUIC_BUG(Failed to write preferred_address) + << "Failed to write preferred_address for " << in; + return false; + } + } + } break; + // initial_source_connection_id + case TransportParameters::kInitialSourceConnectionId: { + if (in.initial_source_connection_id.has_value()) { + QuicConnectionId initial_source_connection_id = + in.initial_source_connection_id.value(); + if (!writer.WriteVarInt62( + TransportParameters::kInitialSourceConnectionId) || + !writer.WriteStringPieceVarInt62( + absl::string_view(initial_source_connection_id.data(), + initial_source_connection_id.length()))) { + QUIC_BUG(Failed to write initial_source_connection_id) + << "Failed to write initial_source_connection_id " + << initial_source_connection_id << " for " << in; + return false; + } + } + } break; + // retry_source_connection_id + case TransportParameters::kRetrySourceConnectionId: { + if (in.retry_source_connection_id.has_value()) { + QUICHE_DCHECK_EQ(Perspective::IS_SERVER, in.perspective); + QuicConnectionId retry_source_connection_id = + in.retry_source_connection_id.value(); + if (!writer.WriteVarInt62( + TransportParameters::kRetrySourceConnectionId) || + !writer.WriteStringPieceVarInt62( + absl::string_view(retry_source_connection_id.data(), + retry_source_connection_id.length()))) { + QUIC_BUG(Failed to write retry_source_connection_id) + << "Failed to write retry_source_connection_id " + << retry_source_connection_id << " for " << in; + return false; + } + } + } break; + // Google-specific connection options. + case TransportParameters::kGoogleConnectionOptions: { + if (in.google_connection_options.has_value()) { + static_assert( + sizeof(in.google_connection_options.value().front()) == 4, + "bad size"); + uint64_t connection_options_length = + in.google_connection_options.value().size() * 4; + if (!writer.WriteVarInt62( + TransportParameters::kGoogleConnectionOptions) || + !writer.WriteVarInt62( + /* transport parameter length */ connection_options_length)) { + QUIC_BUG(Failed to write google_connection_options) + << "Failed to write google_connection_options of length " + << connection_options_length << " for " << in; + return false; + } + for (const QuicTag& connection_option : + in.google_connection_options.value()) { + if (!writer.WriteTag(connection_option)) { + QUIC_BUG(Failed to write google_connection_option) + << "Failed to write google_connection_option " + << QuicTagToString(connection_option) << " for " << in; + return false; + } + } + } + } break; + // Google-specific user agent identifier. + case TransportParameters::kGoogleUserAgentId: { + if (in.user_agent_id.has_value()) { + if (!writer.WriteVarInt62(TransportParameters::kGoogleUserAgentId) || + !writer.WriteStringPieceVarInt62(in.user_agent_id.value())) { + QUIC_BUG(Failed to write Google user agent ID) + << "Failed to write Google user agent ID \"" + << in.user_agent_id.value() << "\" for " << in; + return false; + } + } + } break; + // Google-specific indicator for key update not yet supported. + case TransportParameters::kGoogleKeyUpdateNotYetSupported: { + if (in.key_update_not_yet_supported) { + if (!writer.WriteVarInt62( + TransportParameters::kGoogleKeyUpdateNotYetSupported) || + !writer.WriteVarInt62(/* transport parameter length */ 0)) { + QUIC_BUG(Failed to write key_update_not_yet_supported) + << "Failed to write key_update_not_yet_supported for " << in; + return false; + } + } + } break; + // Google-specific version extension. + case TransportParameters::kGoogleQuicVersion: { + if (!in.legacy_version_information.has_value()) { + break; + } + static_assert(sizeof(QuicVersionLabel) == sizeof(uint32_t), + "bad length"); + uint64_t google_version_length = + sizeof(in.legacy_version_information.value().version); + if (in.perspective == Perspective::IS_SERVER) { + google_version_length += + /* versions length */ sizeof(uint8_t) + + sizeof(QuicVersionLabel) * in.legacy_version_information.value() + .supported_versions.size(); + } + if (!writer.WriteVarInt62(TransportParameters::kGoogleQuicVersion) || + !writer.WriteVarInt62( + /* transport parameter length */ google_version_length) || + !writer.WriteUInt32( + in.legacy_version_information.value().version)) { + QUIC_BUG(Failed to write Google version extension) + << "Failed to write Google version extension for " << in; + return false; + } + if (in.perspective == Perspective::IS_SERVER) { + if (!writer.WriteUInt8(sizeof(QuicVersionLabel) * + in.legacy_version_information.value() + .supported_versions.size())) { + QUIC_BUG(Failed to write versions length) + << "Failed to write versions length for " << in; + return false; + } + for (QuicVersionLabel version_label : + in.legacy_version_information.value().supported_versions) { + if (!writer.WriteUInt32(version_label)) { + QUIC_BUG(Failed to write supported version) + << "Failed to write supported version for " << in; + return false; + } + } + } + } break; + // version_information. + case TransportParameters::kVersionInformation: { + if (!in.version_information.has_value()) { + break; + } + static_assert(sizeof(QuicVersionLabel) == sizeof(uint32_t), + "bad length"); + QuicVersionLabelVector other_versions = + in.version_information.value().other_versions; + // Insert one GREASE version at a random index. + const size_t grease_index = + random->InsecureRandUint64() % (other_versions.size() + 1); + other_versions.insert( + other_versions.begin() + grease_index, + CreateQuicVersionLabel(QuicVersionReservedForNegotiation())); + const uint64_t version_information_length = + sizeof(in.version_information.value().chosen_version) + + sizeof(QuicVersionLabel) * other_versions.size(); + if (!writer.WriteVarInt62(TransportParameters::kVersionInformation) || + !writer.WriteVarInt62( + /* transport parameter length */ version_information_length) || + !writer.WriteUInt32( + in.version_information.value().chosen_version)) { + QUIC_BUG(Failed to write chosen version) + << "Failed to write chosen version for " << in; + return false; + } + for (QuicVersionLabel version_label : other_versions) { + if (!writer.WriteUInt32(version_label)) { + QUIC_BUG(Failed to write other version) + << "Failed to write other version for " << in; + return false; + } + } + } break; + // Custom parameters and GREASE. + default: { + auto it = custom_parameters.find(parameter_id); + if (it == custom_parameters.end()) { + QUIC_BUG(Unknown parameter) << "Unknown parameter " << parameter_id; + return false; + } + if (!writer.WriteVarInt62(parameter_id) || + !writer.WriteStringPieceVarInt62(it->second)) { + QUIC_BUG(Failed to write custom parameter) + << "Failed to write custom parameter " << parameter_id; + return false; + } + } break; } } @@ -1175,6 +1471,22 @@ bool ParseTransportParameters(ParsedQuicVersion version, } } break; case TransportParameters::kGoogleUserAgentId: + if (GetQuicReloadableFlag(quic_ignore_user_agent_transport_parameter)) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_ignore_user_agent_transport_parameter); + // This is a copy of the default switch statement below. + // TODO(dschinazi) remove this case entirely when deprecating the + // quic_ignore_user_agent_transport_parameter flag. + if (out->custom_parameters.find(param_id) != + out->custom_parameters.end()) { + *error_details = "Received a second unknown parameter" + + TransportParameterIdToString(param_id); + return false; + } + out->custom_parameters[param_id] = + std::string(value_reader.ReadRemainingPayload()); + break; + } if (out->user_agent_id.has_value()) { *error_details = "Received a second user_agent_id"; return false; @@ -1182,6 +1494,24 @@ bool ParseTransportParameters(ParsedQuicVersion version, out->user_agent_id = std::string(value_reader.ReadRemainingPayload()); break; case TransportParameters::kGoogleKeyUpdateNotYetSupported: + if (GetQuicReloadableFlag(quic_ignore_key_update_not_yet_supported)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_ignore_key_update_not_yet_supported, + 1, 2); + QUIC_CODE_COUNT(quic_ignore_key_update_not_yet_supported_ignored); + // This is a copy of the default switch statement below. + // TODO(dschinazi) remove this case entirely when deprecating the + // quic_ignore_key_update_not_yet_supported flag. + if (out->custom_parameters.find(param_id) != + out->custom_parameters.end()) { + *error_details = "Received a second unknown parameter" + + TransportParameterIdToString(param_id); + return false; + } + out->custom_parameters[param_id] = + std::string(value_reader.ReadRemainingPayload()); + break; + } + QUIC_CODE_COUNT(quic_ignore_key_update_not_yet_supported_received); if (out->key_update_not_yet_supported) { *error_details = "Received a second key_update_not_yet_supported"; return false; @@ -1189,7 +1519,12 @@ bool ParseTransportParameters(ParsedQuicVersion version, out->key_update_not_yet_supported = true; break; case TransportParameters::kGoogleQuicVersion: { - if (!value_reader.ReadUInt32(&out->version)) { + if (!out->legacy_version_information.has_value()) { + out->legacy_version_information = + TransportParameters::LegacyVersionInformation(); + } + if (!value_reader.ReadUInt32( + &out->legacy_version_information.value().version)) { *error_details = "Failed to read Google version extension version"; return false; } @@ -1206,8 +1541,44 @@ bool ParseTransportParameters(ParsedQuicVersion version, *error_details = "Failed to parse Google supported version"; return false; } - out->supported_versions.push_back(version); + out->legacy_version_information.value() + .supported_versions.push_back(version); + } + } + } break; + case TransportParameters::kVersionInformation: { + if (!GetQuicReloadableFlag(quic_version_information)) { + // This duplicates the default case and will be removed when this flag + // is deprecated. + if (out->custom_parameters.find(param_id) != + out->custom_parameters.end()) { + *error_details = "Received a second unknown parameter" + + TransportParameterIdToString(param_id); + return false; + } + out->custom_parameters[param_id] = + std::string(value_reader.ReadRemainingPayload()); + break; + } + QUIC_RELOADABLE_FLAG_COUNT_N(quic_version_information, 2, 2); + if (out->version_information.has_value()) { + *error_details = "Received a second version_information"; + return false; + } + out->version_information = TransportParameters::VersionInformation(); + if (!value_reader.ReadUInt32( + &out->version_information.value().chosen_version)) { + *error_details = "Failed to read chosen version"; + return false; + } + while (!value_reader.IsDoneReading()) { + QuicVersionLabel other_version; + if (!value_reader.ReadUInt32(&other_version)) { + *error_details = "Failed to parse other version"; + return false; } + out->version_information.value().other_versions.push_back( + other_version); } } break; case TransportParameters::kMinAckDelay: diff --git a/gquiche/quic/core/crypto/transport_parameters.h b/gquiche/quic/core/crypto/transport_parameters.h index 16bd6d2d..e3b4fdb9 100644 --- a/gquiche/quic/core/crypto/transport_parameters.h +++ b/gquiche/quic/core/crypto/transport_parameters.h @@ -110,6 +110,64 @@ struct QUIC_EXPORT_PRIVATE TransportParameters { const TransportParameters& params); }; + // LegacyVersionInformation represents the Google QUIC downgrade prevention + // mechanism ported to QUIC+TLS. It is exchanged using transport parameter ID + // 0x4752 and will eventually be deprecated in favor of + // draft-ietf-quic-version-negotiation. + struct QUIC_EXPORT_PRIVATE LegacyVersionInformation { + LegacyVersionInformation(); + LegacyVersionInformation(const LegacyVersionInformation& other) = default; + LegacyVersionInformation& operator=(const LegacyVersionInformation& other) = + default; + LegacyVersionInformation& operator=(LegacyVersionInformation&& other) = + default; + LegacyVersionInformation(LegacyVersionInformation&& other) = default; + ~LegacyVersionInformation() = default; + bool operator==(const LegacyVersionInformation& rhs) const; + bool operator!=(const LegacyVersionInformation& rhs) const; + // When sent by the client, |version| is the initial version offered by the + // client (before any version negotiation packets) for this connection. When + // sent by the server, |version| is the version that is in use. + QuicVersionLabel version; + + // When sent by the server, |supported_versions| contains a list of all + // versions that the server would send in a version negotiation packet. When + // sent by the client, this is empty. + QuicVersionLabelVector supported_versions; + + // Allows easily logging. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, + const LegacyVersionInformation& legacy_version_information); + }; + + // Version information used for version downgrade prevention and compatible + // version negotiation. See draft-ietf-quic-version-negotiation-05. + struct QUIC_EXPORT_PRIVATE VersionInformation { + VersionInformation(); + VersionInformation(const VersionInformation& other) = default; + VersionInformation& operator=(const VersionInformation& other) = default; + VersionInformation& operator=(VersionInformation&& other) = default; + VersionInformation(VersionInformation&& other) = default; + ~VersionInformation() = default; + bool operator==(const VersionInformation& rhs) const; + bool operator!=(const VersionInformation& rhs) const; + + // Version that the sender has chosen to use on this connection. + QuicVersionLabel chosen_version; + + // When sent by the client, |other_versions| contains all the versions that + // this first flight is compatible with. When sent by the server, + // |other_versions| contains all of the versions supported by the server. + QuicVersionLabelVector other_versions; + + // Allows easily logging. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const VersionInformation& version_information); + }; + TransportParameters(); TransportParameters(const TransportParameters& other); ~TransportParameters(); @@ -122,15 +180,12 @@ struct QUIC_EXPORT_PRIVATE TransportParameters { // the encrypted_extensions handshake message. Perspective perspective; - // When Perspective::IS_CLIENT, |version| is the initial version offered by - // the client (before any version negotiation packets) for this connection. - // When Perspective::IS_SERVER, |version| is the version that is in use. - QuicVersionLabel version; + // Google QUIC downgrade prevention mechanism sent over QUIC+TLS. + absl::optional legacy_version_information; - // |supported_versions| contains a list of all versions that the server would - // send in a version negotiation packet. It is not used if |perspective == - // Perspective::IS_CLIENT|. - QuicVersionLabelVector supported_versions; + // IETF downgrade prevention and compatible version negotiation, see + // draft-ietf-quic-version-negotiation. + absl::optional version_information; // The value of the Destination Connection ID field from the first // Initial packet sent by the client. diff --git a/gquiche/quic/core/crypto/transport_parameters_test.cc b/gquiche/quic/core/crypto/transport_parameters_test.cc index 7b90d8a3..6b0b4fff 100644 --- a/gquiche/quic/core/crypto/transport_parameters_test.cc +++ b/gquiche/quic/core/crypto/transport_parameters_test.cc @@ -99,6 +99,30 @@ CreateFakePreferredAddress() { preferred_address); } +TransportParameters::LegacyVersionInformation +CreateFakeLegacyVersionInformationClient() { + TransportParameters::LegacyVersionInformation legacy_version_information; + legacy_version_information.version = kFakeVersionLabel; + return legacy_version_information; +} + +TransportParameters::LegacyVersionInformation +CreateFakeLegacyVersionInformationServer() { + TransportParameters::LegacyVersionInformation legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + legacy_version_information.supported_versions.push_back(kFakeVersionLabel); + legacy_version_information.supported_versions.push_back(kFakeVersionLabel2); + return legacy_version_information; +} + +TransportParameters::VersionInformation CreateFakeVersionInformation() { + TransportParameters::VersionInformation version_information; + version_information.chosen_version = kFakeVersionLabel; + version_information.other_versions.push_back(kFakeVersionLabel); + version_information.other_versions.push_back(kFakeVersionLabel2); + return version_information; +} + QuicTagVector CreateFakeGoogleConnectionOptions() { return {kALPN, MakeQuicTag('E', 'F', 'G', 0x00), MakeQuicTag('H', 'I', 'J', 0xff)}; @@ -119,6 +143,18 @@ void RemoveGreaseParameters(TransportParameters* params) { for (TransportParameters::TransportParameterId param_id : grease_params) { params->custom_parameters.erase(param_id); } + // Remove all GREASE versions from version_information.other_versions. + if (params->version_information.has_value()) { + QuicVersionLabelVector& other_versions = + params->version_information.value().other_versions; + for (auto it = other_versions.begin(); it != other_versions.end();) { + if ((*it & 0x0f0f0f0f) == 0x0a0a0a0a) { + it = other_versions.erase(it); + } else { + ++it; + } + } + } } } // namespace @@ -145,8 +181,12 @@ TEST_P(TransportParametersTest, Comparator) { EXPECT_FALSE(orig_params == new_params); EXPECT_TRUE(orig_params != new_params); new_params.perspective = Perspective::IS_CLIENT; - orig_params.version = kFakeVersionLabel; - new_params.version = kFakeVersionLabel; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + new_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.version_information = CreateFakeVersionInformation(); + new_params.version_information = CreateFakeVersionInformation(); orig_params.disable_active_migration = true; new_params.disable_active_migration = true; EXPECT_EQ(orig_params, new_params); @@ -154,13 +194,16 @@ TEST_P(TransportParametersTest, Comparator) { EXPECT_FALSE(orig_params != new_params); // Test comparison on vectors. - orig_params.supported_versions.push_back(kFakeVersionLabel); - new_params.supported_versions.push_back(kFakeVersionLabel2); + orig_params.legacy_version_information.value().supported_versions.push_back( + kFakeVersionLabel); + new_params.legacy_version_information.value().supported_versions.push_back( + kFakeVersionLabel2); EXPECT_NE(orig_params, new_params); EXPECT_FALSE(orig_params == new_params); EXPECT_TRUE(orig_params != new_params); - new_params.supported_versions.pop_back(); - new_params.supported_versions.push_back(kFakeVersionLabel); + new_params.legacy_version_information.value().supported_versions.pop_back(); + new_params.legacy_version_information.value().supported_versions.push_back( + kFakeVersionLabel); orig_params.stateless_reset_token = CreateStatelessResetTokenForTest(); new_params.stateless_reset_token = CreateStatelessResetTokenForTest(); EXPECT_EQ(orig_params, new_params); @@ -219,9 +262,9 @@ TEST_P(TransportParametersTest, Comparator) { TEST_P(TransportParametersTest, CopyConstructor) { TransportParameters orig_params; orig_params.perspective = Perspective::IS_CLIENT; - orig_params.version = kFakeVersionLabel; - orig_params.supported_versions.push_back(kFakeVersionLabel); - orig_params.supported_versions.push_back(kFakeVersionLabel2); + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.version_information = CreateFakeVersionInformation(); orig_params.original_destination_connection_id = CreateFakeOriginalDestinationConnectionId(); orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); @@ -260,7 +303,11 @@ TEST_P(TransportParametersTest, CopyConstructor) { TEST_P(TransportParametersTest, RoundTripClient) { TransportParameters orig_params; orig_params.perspective = Perspective::IS_CLIENT; - orig_params.version = kFakeVersionLabel; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + if (GetQuicReloadableFlag(quic_version_information)) { + orig_params.version_information = CreateFakeVersionInformation(); + } orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); orig_params.initial_max_data.set_value(kFakeInitialMaxData); @@ -282,8 +329,12 @@ TEST_P(TransportParametersTest, RoundTripClient) { CreateFakeInitialSourceConnectionId(); orig_params.initial_round_trip_time_us.set_value(kFakeInitialRoundTripTime); orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); - orig_params.user_agent_id = CreateFakeUserAgentId(); - orig_params.key_update_not_yet_supported = kFakeKeyUpdateNotYetSupported; + if (!GetQuicReloadableFlag(quic_ignore_user_agent_transport_parameter)) { + orig_params.user_agent_id = CreateFakeUserAgentId(); + } + if (!GetQuicReloadableFlag(quic_ignore_key_update_not_yet_supported)) { + orig_params.key_update_not_yet_supported = kFakeKeyUpdateNotYetSupported; + } orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; @@ -304,9 +355,11 @@ TEST_P(TransportParametersTest, RoundTripClient) { TEST_P(TransportParametersTest, RoundTripServer) { TransportParameters orig_params; orig_params.perspective = Perspective::IS_SERVER; - orig_params.version = kFakeVersionLabel; - orig_params.supported_versions.push_back(kFakeVersionLabel); - orig_params.supported_versions.push_back(kFakeVersionLabel2); + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationServer(); + if (GetQuicReloadableFlag(quic_version_information)) { + orig_params.version_information = CreateFakeVersionInformation(); + } orig_params.original_destination_connection_id = CreateFakeOriginalDestinationConnectionId(); orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); @@ -442,7 +495,8 @@ TEST_P(TransportParametersTest, AreValid) { TEST_P(TransportParametersTest, NoClientParamsWithStatelessResetToken) { TransportParameters orig_params; orig_params.perspective = Perspective::IS_CLIENT; - orig_params.version = kFakeVersionLabel; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); orig_params.stateless_reset_token = CreateStatelessResetTokenForTest(); orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); @@ -535,6 +589,12 @@ TEST_P(TransportParametersTest, ParseClientParams) { 0x80, 0x00, 0x47, 0x52, // parameter id 0x04, // length 0x01, 0x23, 0x45, 0x67, // initial version + // version_information + 0x80, 0xFF, 0x73, 0xDB, // parameter id + 0x0C, // length + 0x01, 0x23, 0x45, 0x67, // chosen version + 0x01, 0x23, 0x45, 0x67, // other version 1 + 0x89, 0xab, 0xcd, 0xef, // other version 2 }; // clang-format on const uint8_t* client_params = @@ -548,8 +608,16 @@ TEST_P(TransportParametersTest, ParseClientParams) { << error_details; EXPECT_TRUE(error_details.empty()); EXPECT_EQ(Perspective::IS_CLIENT, new_params.perspective); - EXPECT_EQ(kFakeVersionLabel, new_params.version); - EXPECT_TRUE(new_params.supported_versions.empty()); + ASSERT_TRUE(new_params.legacy_version_information.has_value()); + EXPECT_EQ(kFakeVersionLabel, + new_params.legacy_version_information.value().version); + EXPECT_TRUE( + new_params.legacy_version_information.value().supported_versions.empty()); + if (GetQuicReloadableFlag(quic_version_information)) { + ASSERT_TRUE(new_params.version_information.has_value()); + EXPECT_EQ(new_params.version_information.value(), + CreateFakeVersionInformation()); + } EXPECT_FALSE(new_params.original_destination_connection_id.has_value()); EXPECT_EQ(kFakeIdleTimeoutMilliseconds, new_params.max_idle_timeout_ms.value()); @@ -581,9 +649,15 @@ TEST_P(TransportParametersTest, ParseClientParams) { ASSERT_TRUE(new_params.google_connection_options.has_value()); EXPECT_EQ(CreateFakeGoogleConnectionOptions(), new_params.google_connection_options.value()); - ASSERT_TRUE(new_params.user_agent_id.has_value()); - EXPECT_EQ(CreateFakeUserAgentId(), new_params.user_agent_id.value()); - EXPECT_TRUE(new_params.key_update_not_yet_supported); + if (!GetQuicReloadableFlag(quic_ignore_user_agent_transport_parameter)) { + ASSERT_TRUE(new_params.user_agent_id.has_value()); + EXPECT_EQ(CreateFakeUserAgentId(), new_params.user_agent_id.value()); + } else { + EXPECT_FALSE(new_params.user_agent_id.has_value()); + } + if (!GetQuicReloadableFlag(quic_ignore_key_update_not_yet_supported)) { + EXPECT_TRUE(new_params.key_update_not_yet_supported); + } } TEST_P(TransportParametersTest, @@ -780,6 +854,12 @@ TEST_P(TransportParametersTest, ParseServerParams) { 0x08, // length of supported versions array 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + // version_information + 0x80, 0xFF, 0x73, 0xDB, // parameter id + 0x0C, // length + 0x01, 0x23, 0x45, 0x67, // chosen version + 0x01, 0x23, 0x45, 0x67, // other version 1 + 0x89, 0xab, 0xcd, 0xef, // other version 2 }; // clang-format on const uint8_t* server_params = @@ -793,10 +873,23 @@ TEST_P(TransportParametersTest, ParseServerParams) { << error_details; EXPECT_TRUE(error_details.empty()); EXPECT_EQ(Perspective::IS_SERVER, new_params.perspective); - EXPECT_EQ(kFakeVersionLabel, new_params.version); - EXPECT_EQ(2u, new_params.supported_versions.size()); - EXPECT_EQ(kFakeVersionLabel, new_params.supported_versions[0]); - EXPECT_EQ(kFakeVersionLabel2, new_params.supported_versions[1]); + ASSERT_TRUE(new_params.legacy_version_information.has_value()); + EXPECT_EQ(kFakeVersionLabel, + new_params.legacy_version_information.value().version); + ASSERT_EQ( + 2u, + new_params.legacy_version_information.value().supported_versions.size()); + EXPECT_EQ( + kFakeVersionLabel, + new_params.legacy_version_information.value().supported_versions[0]); + EXPECT_EQ( + kFakeVersionLabel2, + new_params.legacy_version_information.value().supported_versions[1]); + if (GetQuicReloadableFlag(quic_version_information)) { + ASSERT_TRUE(new_params.version_information.has_value()); + EXPECT_EQ(new_params.version_information.value(), + CreateFakeVersionInformation()); + } ASSERT_TRUE(new_params.original_destination_connection_id.has_value()); EXPECT_EQ(CreateFakeOriginalDestinationConnectionId(), new_params.original_destination_connection_id.value()); @@ -841,7 +934,9 @@ TEST_P(TransportParametersTest, ParseServerParams) { EXPECT_EQ(CreateFakeGoogleConnectionOptions(), new_params.google_connection_options.value()); EXPECT_FALSE(new_params.user_agent_id.has_value()); - EXPECT_TRUE(new_params.key_update_not_yet_supported); + if (!GetQuicReloadableFlag(quic_ignore_key_update_not_yet_supported)) { + EXPECT_TRUE(new_params.key_update_not_yet_supported); + } } TEST_P(TransportParametersTest, ParseServerParametersRepeated) { @@ -915,7 +1010,8 @@ TEST_P(TransportParametersTest, VeryLongCustomParameter) { std::string custom_value(70000, '?'); TransportParameters orig_params; orig_params.perspective = Perspective::IS_CLIENT; - orig_params.version = kFakeVersionLabel; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); orig_params.custom_parameters[kCustomParameter1] = custom_value; std::vector serialized; @@ -932,13 +1028,59 @@ TEST_P(TransportParametersTest, VeryLongCustomParameter) { EXPECT_EQ(new_params, orig_params); } +TEST_P(TransportParametersTest, SerializationOrderIsRandom) { + TransportParameters orig_params; + orig_params.perspective = Perspective::IS_CLIENT; + orig_params.legacy_version_information = + CreateFakeLegacyVersionInformationClient(); + orig_params.max_idle_timeout_ms.set_value(kFakeIdleTimeoutMilliseconds); + orig_params.max_udp_payload_size.set_value(kMaxPacketSizeForTest); + orig_params.initial_max_data.set_value(kFakeInitialMaxData); + orig_params.initial_max_stream_data_bidi_local.set_value( + kFakeInitialMaxStreamDataBidiLocal); + orig_params.initial_max_stream_data_bidi_remote.set_value( + kFakeInitialMaxStreamDataBidiRemote); + orig_params.initial_max_stream_data_uni.set_value( + kFakeInitialMaxStreamDataUni); + orig_params.initial_max_streams_bidi.set_value(kFakeInitialMaxStreamsBidi); + orig_params.initial_max_streams_uni.set_value(kFakeInitialMaxStreamsUni); + orig_params.ack_delay_exponent.set_value(kAckDelayExponentForTest); + orig_params.max_ack_delay.set_value(kMaxAckDelayForTest); + orig_params.min_ack_delay_us.set_value(kMinAckDelayUsForTest); + orig_params.disable_active_migration = kFakeDisableMigration; + orig_params.active_connection_id_limit.set_value( + kActiveConnectionIdLimitForTest); + orig_params.initial_source_connection_id = + CreateFakeInitialSourceConnectionId(); + orig_params.initial_round_trip_time_us.set_value(kFakeInitialRoundTripTime); + orig_params.google_connection_options = CreateFakeGoogleConnectionOptions(); + orig_params.user_agent_id = CreateFakeUserAgentId(); + orig_params.key_update_not_yet_supported = kFakeKeyUpdateNotYetSupported; + orig_params.custom_parameters[kCustomParameter1] = kCustomParameter1Value; + orig_params.custom_parameters[kCustomParameter2] = kCustomParameter2Value; + + std::vector first_serialized; + ASSERT_TRUE( + SerializeTransportParameters(version_, orig_params, &first_serialized)); + // Test that a subsequent serialization is different from the first. + // Run in a loop to avoid a failure in the unlikely event that randomization + // produces the same result multiple times. + for (int i = 0; i < 1000; i++) { + std::vector serialized; + ASSERT_TRUE( + SerializeTransportParameters(version_, orig_params, &serialized)); + if (serialized != first_serialized) { + return; + } + } +} + class TransportParametersTicketSerializationTest : public QuicTest { protected: void SetUp() override { original_params_.perspective = Perspective::IS_SERVER; - original_params_.version = kFakeVersionLabel; - original_params_.supported_versions.push_back(kFakeVersionLabel); - original_params_.supported_versions.push_back(kFakeVersionLabel2); + original_params_.legacy_version_information = + CreateFakeLegacyVersionInformationServer(); original_params_.original_destination_connection_id = CreateFakeOriginalDestinationConnectionId(); original_params_.max_idle_timeout_ms.set_value( diff --git a/gquiche/quic/core/frames/quic_frame.cc b/gquiche/quic/core/frames/quic_frame.cc index cc4ff0d5..359bcd15 100644 --- a/gquiche/quic/core/frames/quic_frame.cc +++ b/gquiche/quic/core/frames/quic_frame.cc @@ -391,11 +391,9 @@ QuicFrame CopyQuicFrame(QuicBufferAllocator* allocator, copy.message_frame->data = frame.message_frame->data; copy.message_frame->message_length = frame.message_frame->message_length; for (const auto& slice : frame.message_frame->message_data) { - QuicUniqueBufferPtr buffer = - MakeUniqueBuffer(allocator, slice.length()); - memcpy(buffer.get(), slice.data(), slice.length()); + QuicBuffer buffer = QuicBuffer::Copy(allocator, slice.AsStringView()); copy.message_frame->message_data.push_back( - QuicMemSlice(std::move(buffer), slice.length())); + QuicMemSlice(std::move(buffer))); } break; case NEW_TOKEN_FRAME: diff --git a/gquiche/quic/core/frames/quic_frame.h b/gquiche/quic/core/frames/quic_frame.h index 7e5ae72f..faaa82db 100644 --- a/gquiche/quic/core/frames/quic_frame.h +++ b/gquiche/quic/core/frames/quic_frame.h @@ -9,6 +9,7 @@ #include #include +#include "absl/container/inlined_vector.h" #include "gquiche/quic/core/frames/quic_ack_frame.h" #include "gquiche/quic/core/frames/quic_ack_frequency_frame.h" #include "gquiche/quic/core/frames/quic_blocked_frame.h" @@ -133,7 +134,7 @@ static_assert(offsetof(QuicStreamFrame, type) == offsetof(QuicFrame, type), // A inline size of 1 is chosen to optimize the typical use case of // 1-stream-frame in QuicTransmissionInfo.retransmittable_frames. -using QuicFrames = QuicInlinedVector; +using QuicFrames = absl::InlinedVector; // Deletes all the sub-frames contained in |frames|. QUIC_EXPORT_PRIVATE void DeleteFrames(QuicFrames* frames); diff --git a/gquiche/quic/core/frames/quic_frames_test.cc b/gquiche/quic/core/frames/quic_frames_test.cc index 251bfbb1..2bda257c 100644 --- a/gquiche/quic/core/frames/quic_frames_test.cc +++ b/gquiche/quic/core/frames/quic_frames_test.cc @@ -85,7 +85,8 @@ TEST_F(QuicFramesTest, RstStreamFrameToString) { std::ostringstream stream; stream << rst_stream; EXPECT_EQ( - "{ control_frame_id: 1, stream_id: 1, byte_offset: 3, error_code: 6 }\n", + "{ control_frame_id: 1, stream_id: 1, byte_offset: 3, error_code: 6, " + "ietf_error_code: 0 }\n", stream.str()); EXPECT_TRUE(IsControlFrame(frame.type)); } @@ -543,10 +544,8 @@ TEST_F(QuicFramesTest, RemoveSmallestInterval) { TEST_F(QuicFramesTest, CopyQuicFrames) { QuicFrames frames; - SimpleBufferAllocator allocator; - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); QuicMessageFrame* message_frame = - new QuicMessageFrame(1, MakeSpan(&allocator, "message", &storage)); + new QuicMessageFrame(1, MemSliceFromString("message")); // Construct a frame list. for (uint8_t i = 0; i < NUM_FRAME_TYPES; ++i) { switch (i) { @@ -626,7 +625,7 @@ TEST_F(QuicFramesTest, CopyQuicFrames) { } } - QuicFrames copy = CopyQuicFrames(&allocator, frames); + QuicFrames copy = CopyQuicFrames(SimpleBufferAllocator::Get(), frames); ASSERT_EQ(NUM_FRAME_TYPES, copy.size()); for (uint8_t i = 0; i < NUM_FRAME_TYPES; ++i) { EXPECT_EQ(i, copy[i].type); diff --git a/gquiche/quic/core/frames/quic_message_frame.cc b/gquiche/quic/core/frames/quic_message_frame.cc index ae074552..36aad988 100644 --- a/gquiche/quic/core/frames/quic_message_frame.cc +++ b/gquiche/quic/core/frames/quic_message_frame.cc @@ -14,13 +14,18 @@ QuicMessageFrame::QuicMessageFrame(QuicMessageId message_id) : message_id(message_id), data(nullptr), message_length(0) {} QuicMessageFrame::QuicMessageFrame(QuicMessageId message_id, - QuicMemSliceSpan span) + absl::Span span) : message_id(message_id), data(nullptr), message_length(0) { - span.ConsumeAll([&](QuicMemSlice slice) { + for (QuicMemSlice& slice : span) { + if (slice.empty()) { + continue; + } message_length += slice.length(); message_data.push_back(std::move(slice)); - }); + } } +QuicMessageFrame::QuicMessageFrame(QuicMessageId message_id, QuicMemSlice slice) + : QuicMessageFrame(message_id, absl::MakeSpan(&slice, 1)) {} QuicMessageFrame::QuicMessageFrame(const char* data, QuicPacketLength length) : message_id(0), data(data), message_length(length) {} diff --git a/gquiche/quic/core/frames/quic_message_frame.h b/gquiche/quic/core/frames/quic_message_frame.h index 10a9cafe..44b51814 100644 --- a/gquiche/quic/core/frames/quic_message_frame.h +++ b/gquiche/quic/core/frames/quic_message_frame.h @@ -5,20 +5,22 @@ #ifndef QUICHE_QUIC_CORE_FRAMES_QUIC_MESSAGE_FRAME_H_ #define QUICHE_QUIC_CORE_FRAMES_QUIC_MESSAGE_FRAME_H_ +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_mem_slice.h" -#include "gquiche/quic/platform/api/quic_mem_slice_span.h" namespace quic { -using QuicMessageData = QuicInlinedVector; +using QuicMessageData = absl::InlinedVector; struct QUIC_EXPORT_PRIVATE QuicMessageFrame { QuicMessageFrame() = default; explicit QuicMessageFrame(QuicMessageId message_id); - QuicMessageFrame(QuicMessageId message_id, QuicMemSliceSpan span); + QuicMessageFrame(QuicMessageId message_id, absl::Span span); + QuicMessageFrame(QuicMessageId message_id, QuicMemSlice slice); QuicMessageFrame(const char* data, QuicPacketLength length); QuicMessageFrame(const QuicMessageFrame& other) = delete; diff --git a/gquiche/quic/core/frames/quic_new_connection_id_frame.h b/gquiche/quic/core/frames/quic_new_connection_id_frame.h index 531826b1..72e0d37b 100644 --- a/gquiche/quic/core/frames/quic_new_connection_id_frame.h +++ b/gquiche/quic/core/frames/quic_new_connection_id_frame.h @@ -32,7 +32,7 @@ struct QUIC_EXPORT_PRIVATE QuicNewConnectionIdFrame { QuicConnectionId connection_id = EmptyQuicConnectionId(); QuicConnectionIdSequenceNumber sequence_number = 0; StatelessResetToken stateless_reset_token; - uint64_t retire_prior_to; + uint64_t retire_prior_to = 0; }; } // namespace quic diff --git a/gquiche/quic/core/frames/quic_new_token_frame.cc b/gquiche/quic/core/frames/quic_new_token_frame.cc index f47af1cb..03e02112 100644 --- a/gquiche/quic/core/frames/quic_new_token_frame.cc +++ b/gquiche/quic/core/frames/quic_new_token_frame.cc @@ -6,7 +6,6 @@ #include "absl/strings/escaping.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/frames/quic_path_challenge_frame.cc b/gquiche/quic/core/frames/quic_path_challenge_frame.cc index 567cc184..3e6a6188 100644 --- a/gquiche/quic/core/frames/quic_path_challenge_frame.cc +++ b/gquiche/quic/core/frames/quic_path_challenge_frame.cc @@ -6,7 +6,6 @@ #include "absl/strings/escaping.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/frames/quic_path_challenge_frame.h b/gquiche/quic/core/frames/quic_path_challenge_frame.h index 6ebba4fa..60ec63eb 100644 --- a/gquiche/quic/core/frames/quic_path_challenge_frame.h +++ b/gquiche/quic/core/frames/quic_path_challenge_frame.h @@ -27,7 +27,7 @@ struct QUIC_EXPORT_PRIVATE QuicPathChallengeFrame { // and non-zero when sent. QuicControlFrameId control_frame_id = kInvalidControlFrameId; - QuicPathFrameBuffer data_buffer; + QuicPathFrameBuffer data_buffer{}; }; } // namespace quic diff --git a/gquiche/quic/core/frames/quic_path_response_frame.cc b/gquiche/quic/core/frames/quic_path_response_frame.cc index b662bec2..8f922a08 100644 --- a/gquiche/quic/core/frames/quic_path_response_frame.cc +++ b/gquiche/quic/core/frames/quic_path_response_frame.cc @@ -6,7 +6,6 @@ #include "absl/strings/escaping.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/frames/quic_path_response_frame.h b/gquiche/quic/core/frames/quic_path_response_frame.h index 4f48c1c2..5c4e1a24 100644 --- a/gquiche/quic/core/frames/quic_path_response_frame.h +++ b/gquiche/quic/core/frames/quic_path_response_frame.h @@ -27,7 +27,7 @@ struct QUIC_EXPORT_PRIVATE QuicPathResponseFrame { // and non-zero when sent. QuicControlFrameId control_frame_id = kInvalidControlFrameId; - QuicPathFrameBuffer data_buffer; + QuicPathFrameBuffer data_buffer{}; }; } // namespace quic diff --git a/gquiche/quic/core/frames/quic_rst_stream_frame.cc b/gquiche/quic/core/frames/quic_rst_stream_frame.cc index 5ba1c8df..9cbab198 100644 --- a/gquiche/quic/core/frames/quic_rst_stream_frame.cc +++ b/gquiche/quic/core/frames/quic_rst_stream_frame.cc @@ -18,12 +18,23 @@ QuicRstStreamFrame::QuicRstStreamFrame(QuicControlFrameId control_frame_id, ietf_error_code(RstStreamErrorCodeToIetfResetStreamErrorCode(error_code)), byte_offset(bytes_written) {} +QuicRstStreamFrame::QuicRstStreamFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicResetStreamError error, + QuicStreamOffset bytes_written) + : control_frame_id(control_frame_id), + stream_id(stream_id), + error_code(error.internal_code()), + ietf_error_code(error.ietf_application_code()), + byte_offset(bytes_written) {} + std::ostream& operator<<(std::ostream& os, const QuicRstStreamFrame& rst_frame) { os << "{ control_frame_id: " << rst_frame.control_frame_id << ", stream_id: " << rst_frame.stream_id << ", byte_offset: " << rst_frame.byte_offset - << ", error_code: " << rst_frame.error_code << " }\n"; + << ", error_code: " << rst_frame.error_code + << ", ietf_error_code: " << rst_frame.ietf_error_code << " }\n"; return os; } diff --git a/gquiche/quic/core/frames/quic_rst_stream_frame.h b/gquiche/quic/core/frames/quic_rst_stream_frame.h index d0893ba3..e90e43bf 100644 --- a/gquiche/quic/core/frames/quic_rst_stream_frame.h +++ b/gquiche/quic/core/frames/quic_rst_stream_frame.h @@ -16,13 +16,14 @@ namespace quic { struct QUIC_EXPORT_PRIVATE QuicRstStreamFrame { QuicRstStreamFrame() = default; QuicRstStreamFrame(QuicControlFrameId control_frame_id, - QuicStreamId stream_id, - QuicRstStreamErrorCode error_code, + QuicStreamId stream_id, QuicRstStreamErrorCode error_code, + QuicStreamOffset bytes_written); + QuicRstStreamFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, QuicResetStreamError error, QuicStreamOffset bytes_written); friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const QuicRstStreamFrame& r); + std::ostream& os, const QuicRstStreamFrame& r); // A unique identifier of this control frame. 0 when this frame is received, // and non-zero when sent. @@ -45,6 +46,11 @@ struct QUIC_EXPORT_PRIVATE QuicRstStreamFrame { // that stream. This can be done through normal termination (data packet with // FIN) or through a RST. QuicStreamOffset byte_offset = 0; + + // Returns a tuple of both |error_code| and |ietf_error_code|. + QuicResetStreamError error() const { + return QuicResetStreamError(error_code, ietf_error_code); + } }; } // namespace quic diff --git a/gquiche/quic/core/frames/quic_stop_sending_frame.cc b/gquiche/quic/core/frames/quic_stop_sending_frame.cc index 897489a6..5b6d6c13 100644 --- a/gquiche/quic/core/frames/quic_stop_sending_frame.cc +++ b/gquiche/quic/core/frames/quic_stop_sending_frame.cc @@ -11,11 +11,16 @@ namespace quic { QuicStopSendingFrame::QuicStopSendingFrame(QuicControlFrameId control_frame_id, QuicStreamId stream_id, QuicRstStreamErrorCode error_code) + : QuicStopSendingFrame(control_frame_id, stream_id, + QuicResetStreamError::FromInternal(error_code)) {} + +QuicStopSendingFrame::QuicStopSendingFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, + QuicResetStreamError error) : control_frame_id(control_frame_id), stream_id(stream_id), - error_code(error_code), - ietf_error_code( - RstStreamErrorCodeToIetfResetStreamErrorCode(error_code)) {} + error_code(error.internal_code()), + ietf_error_code(error.ietf_application_code()) {} std::ostream& operator<<(std::ostream& os, const QuicStopSendingFrame& frame) { os << "{ control_frame_id: " << frame.control_frame_id diff --git a/gquiche/quic/core/frames/quic_stop_sending_frame.h b/gquiche/quic/core/frames/quic_stop_sending_frame.h index f8754fb8..24bf6e13 100644 --- a/gquiche/quic/core/frames/quic_stop_sending_frame.h +++ b/gquiche/quic/core/frames/quic_stop_sending_frame.h @@ -18,6 +18,8 @@ struct QUIC_EXPORT_PRIVATE QuicStopSendingFrame { QuicStopSendingFrame(QuicControlFrameId control_frame_id, QuicStreamId stream_id, QuicRstStreamErrorCode error_code); + QuicStopSendingFrame(QuicControlFrameId control_frame_id, + QuicStreamId stream_id, QuicResetStreamError error); friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( std::ostream& os, @@ -35,6 +37,11 @@ struct QUIC_EXPORT_PRIVATE QuicStopSendingFrame { // On-the-wire application error code of the frame. uint64_t ietf_error_code = 0; + + // Returns a tuple of both |error_code| and |ietf_error_code|. + QuicResetStreamError error() const { + return QuicResetStreamError(error_code, ietf_error_code); + } }; } // namespace quic diff --git a/gquiche/quic/core/frames/quic_streams_blocked_frame.h b/gquiche/quic/core/frames/quic_streams_blocked_frame.h index 15052653..07f32a00 100644 --- a/gquiche/quic/core/frames/quic_streams_blocked_frame.h +++ b/gquiche/quic/core/frames/quic_streams_blocked_frame.h @@ -35,7 +35,7 @@ struct QUIC_EXPORT_PRIVATE QuicStreamsBlockedFrame QuicControlFrameId control_frame_id = kInvalidControlFrameId; // The number of streams that the sender wishes to exceed - QuicStreamCount stream_count; + QuicStreamCount stream_count = 0; // Whether uni- or bi-directional streams bool unidirectional = false; diff --git a/gquiche/quic/core/frames/quic_window_update_frame.h b/gquiche/quic/core/frames/quic_window_update_frame.h index 070963a3..1de79acb 100644 --- a/gquiche/quic/core/frames/quic_window_update_frame.h +++ b/gquiche/quic/core/frames/quic_window_update_frame.h @@ -31,11 +31,11 @@ struct QUIC_EXPORT_PRIVATE QuicWindowUpdateFrame { // The stream this frame applies to. 0 is a special case meaning the overall // connection rather than a specific stream. - QuicStreamId stream_id; + QuicStreamId stream_id = 0; // Maximum data allowed in the stream or connection. The receiver of this // frame must not send data which would exceedes this restriction. - QuicByteCount max_data; + QuicByteCount max_data = 0; }; } // namespace quic diff --git a/gquiche/quic/core/http/capsule.cc b/gquiche/quic/core/http/capsule.cc new file mode 100644 index 00000000..23d55333 --- /dev/null +++ b/gquiche/quic/core/http/capsule.cc @@ -0,0 +1,749 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/http/capsule.h" + +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gquiche/quic/core/http/http_frames.h" +#include "gquiche/quic/core/quic_data_reader.h" +#include "gquiche/quic/core/quic_data_writer.h" +#include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace quic { + +std::string CapsuleTypeToString(CapsuleType capsule_type) { + switch (capsule_type) { + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + return "REGISTER_DATAGRAM_CONTEXT"; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + return "CLOSE_DATAGRAM_CONTEXT"; + case CapsuleType::LEGACY_DATAGRAM: + return "LEGACY_DATAGRAM"; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + return "DATAGRAM_WITH_CONTEXT"; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + return "DATAGRAM_WITHOUT_CONTEXT"; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + return "REGISTER_DATAGRAM_NO_CONTEXT"; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + return "CLOSE_WEBTRANSPORT_SESSION"; + } + return absl::StrCat("Unknown(", static_cast(capsule_type), ")"); +} + +std::ostream& operator<<(std::ostream& os, const CapsuleType& capsule_type) { + os << CapsuleTypeToString(capsule_type); + return os; +} + +std::string DatagramFormatTypeToString( + DatagramFormatType datagram_format_type) { + switch (datagram_format_type) { + case DatagramFormatType::UDP_PAYLOAD: + return "UDP_PAYLOAD"; + case DatagramFormatType::WEBTRANSPORT: + return "WEBTRANSPORT"; + } + return absl::StrCat("Unknown(", static_cast(datagram_format_type), + ")"); +} + +std::ostream& operator<<(std::ostream& os, + const DatagramFormatType& datagram_format_type) { + os << DatagramFormatTypeToString(datagram_format_type); + return os; +} + +std::string ContextCloseCodeToString(ContextCloseCode context_close_code) { + switch (context_close_code) { + case ContextCloseCode::CLOSE_NO_ERROR: + return "NO_ERROR"; + case ContextCloseCode::UNKNOWN_FORMAT: + return "UNKNOWN_FORMAT"; + case ContextCloseCode::DENIED: + return "DENIED"; + case ContextCloseCode::RESOURCE_LIMIT: + return "RESOURCE_LIMIT"; + } + return absl::StrCat("Unknown(", static_cast(context_close_code), + ")"); +} + +std::ostream& operator<<(std::ostream& os, + const ContextCloseCode& context_close_code) { + os << ContextCloseCodeToString(context_close_code); + return os; +} + +Capsule::Capsule(CapsuleType capsule_type) : capsule_type_(capsule_type) { + switch (capsule_type) { + case CapsuleType::LEGACY_DATAGRAM: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible::value, + "All capsule structs must have these properties"); + legacy_datagram_capsule_ = LegacyDatagramCapsule(); + break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible::value, + "All capsule structs must have these properties"); + datagram_with_context_capsule_ = DatagramWithContextCapsule(); + break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible< + DatagramWithoutContextCapsule>::value, + "All capsule structs must have these properties"); + datagram_without_context_capsule_ = DatagramWithoutContextCapsule(); + break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible< + RegisterDatagramContextCapsule>::value, + "All capsule structs must have these properties"); + register_datagram_context_capsule_ = RegisterDatagramContextCapsule(); + break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible< + RegisterDatagramNoContextCapsule>::value, + "All capsule structs must have these properties"); + register_datagram_no_context_capsule_ = + RegisterDatagramNoContextCapsule(); + break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible< + CloseDatagramContextCapsule>::value, + "All capsule structs must have these properties"); + close_datagram_context_capsule_ = CloseDatagramContextCapsule(); + break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + static_assert( + std::is_standard_layout::value && + std::is_trivially_destructible< + CloseWebTransportSessionCapsule>::value, + "All capsule structs must have these properties"); + close_web_transport_session_capsule_ = CloseWebTransportSessionCapsule(); + break; + default: + unknown_capsule_data_ = absl::string_view(); + break; + } +} + +// static +Capsule Capsule::LegacyDatagram( + absl::optional context_id, + absl::string_view http_datagram_payload) { + Capsule capsule(CapsuleType::LEGACY_DATAGRAM); + capsule.legacy_datagram_capsule().context_id = context_id; + capsule.legacy_datagram_capsule().http_datagram_payload = + http_datagram_payload; + return capsule; +} + +// static +Capsule Capsule::DatagramWithContext(QuicDatagramContextId context_id, + absl::string_view http_datagram_payload) { + Capsule capsule(CapsuleType::DATAGRAM_WITH_CONTEXT); + capsule.datagram_with_context_capsule().context_id = context_id; + capsule.datagram_with_context_capsule().http_datagram_payload = + http_datagram_payload; + return capsule; +} + +// static +Capsule Capsule::DatagramWithoutContext( + absl::string_view http_datagram_payload) { + Capsule capsule(CapsuleType::DATAGRAM_WITHOUT_CONTEXT); + capsule.datagram_without_context_capsule().http_datagram_payload = + http_datagram_payload; + return capsule; +} + +// static +Capsule Capsule::RegisterDatagramContext( + QuicDatagramContextId context_id, DatagramFormatType format_type, + absl::string_view format_additional_data) { + Capsule capsule(CapsuleType::REGISTER_DATAGRAM_CONTEXT); + capsule.register_datagram_context_capsule().context_id = context_id; + capsule.register_datagram_context_capsule().format_type = format_type; + capsule.register_datagram_context_capsule().format_additional_data = + format_additional_data; + return capsule; +} + +// static +Capsule Capsule::RegisterDatagramNoContext( + DatagramFormatType format_type, absl::string_view format_additional_data) { + Capsule capsule(CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT); + capsule.register_datagram_no_context_capsule().format_type = format_type; + capsule.register_datagram_no_context_capsule().format_additional_data = + format_additional_data; + return capsule; +} + +// static +Capsule Capsule::CloseDatagramContext(QuicDatagramContextId context_id, + ContextCloseCode close_code, + absl::string_view close_details) { + Capsule capsule(CapsuleType::CLOSE_DATAGRAM_CONTEXT); + capsule.close_datagram_context_capsule().context_id = context_id; + capsule.close_datagram_context_capsule().close_code = close_code; + capsule.close_datagram_context_capsule().close_details = close_details; + return capsule; +} + +// static +Capsule Capsule::CloseWebTransportSession(WebTransportSessionError error_code, + absl::string_view error_message) { + Capsule capsule(CapsuleType::CLOSE_WEBTRANSPORT_SESSION); + capsule.close_web_transport_session_capsule().error_code = error_code; + capsule.close_web_transport_session_capsule().error_message = error_message; + return capsule; +} + +// static +Capsule Capsule::Unknown(uint64_t capsule_type, + absl::string_view unknown_capsule_data) { + Capsule capsule(static_cast(capsule_type)); + capsule.unknown_capsule_data() = unknown_capsule_data; + return capsule; +} + +Capsule& Capsule::operator=(const Capsule& other) { + capsule_type_ = other.capsule_type_; + switch (capsule_type_) { + case CapsuleType::LEGACY_DATAGRAM: + legacy_datagram_capsule_ = other.legacy_datagram_capsule_; + break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + datagram_with_context_capsule_ = other.datagram_with_context_capsule_; + break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + datagram_without_context_capsule_ = + other.datagram_without_context_capsule_; + break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + register_datagram_context_capsule_ = + other.register_datagram_context_capsule_; + break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + register_datagram_no_context_capsule_ = + other.register_datagram_no_context_capsule_; + break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + close_datagram_context_capsule_ = other.close_datagram_context_capsule_; + break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + close_web_transport_session_capsule_ = + other.close_web_transport_session_capsule_; + break; + default: + unknown_capsule_data_ = other.unknown_capsule_data_; + break; + } + return *this; +} + +Capsule::Capsule(const Capsule& other) : Capsule(other.capsule_type_) { + *this = other; +} + +bool Capsule::operator==(const Capsule& other) const { + if (capsule_type_ != other.capsule_type_) { + return false; + } + switch (capsule_type_) { + case CapsuleType::LEGACY_DATAGRAM: + return legacy_datagram_capsule_.context_id == + other.legacy_datagram_capsule_.context_id && + legacy_datagram_capsule_.http_datagram_payload == + other.legacy_datagram_capsule_.http_datagram_payload; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + return datagram_with_context_capsule_.context_id == + other.datagram_with_context_capsule_.context_id && + datagram_with_context_capsule_.http_datagram_payload == + other.datagram_with_context_capsule_.http_datagram_payload; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + return datagram_without_context_capsule_.http_datagram_payload == + other.datagram_without_context_capsule_.http_datagram_payload; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + return register_datagram_context_capsule_.context_id == + other.register_datagram_context_capsule_.context_id && + register_datagram_context_capsule_.format_type == + other.register_datagram_context_capsule_.format_type && + register_datagram_context_capsule_.format_additional_data == + other.register_datagram_context_capsule_ + .format_additional_data; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + return register_datagram_no_context_capsule_.format_type == + other.register_datagram_no_context_capsule_.format_type && + register_datagram_no_context_capsule_.format_additional_data == + other.register_datagram_no_context_capsule_ + .format_additional_data; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + return close_datagram_context_capsule_.context_id == + other.close_datagram_context_capsule_.context_id && + close_datagram_context_capsule_.close_code == + other.close_datagram_context_capsule_.close_code && + close_datagram_context_capsule_.close_details == + other.close_datagram_context_capsule_.close_details; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + return close_web_transport_session_capsule_.error_code == + other.close_web_transport_session_capsule_.error_code && + close_web_transport_session_capsule_.error_message == + other.close_web_transport_session_capsule_.error_message; + default: + return unknown_capsule_data_ == other.unknown_capsule_data_; + } +} + +std::string Capsule::ToString() const { + std::string rv = CapsuleTypeToString(capsule_type_); + switch (capsule_type_) { + case CapsuleType::LEGACY_DATAGRAM: + if (legacy_datagram_capsule_.context_id.has_value()) { + absl::StrAppend(&rv, "(", legacy_datagram_capsule_.context_id.value(), + ")"); + } + absl::StrAppend(&rv, "[", + absl::BytesToHexString( + legacy_datagram_capsule_.http_datagram_payload), + "]"); + break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + absl::StrAppend(&rv, "(", datagram_with_context_capsule_.context_id, ")[", + absl::BytesToHexString( + datagram_with_context_capsule_.http_datagram_payload), + "]"); + break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + absl::StrAppend( + &rv, "[", + absl::BytesToHexString( + datagram_without_context_capsule_.http_datagram_payload), + "]"); + break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + absl::StrAppend( + &rv, "(context_id=", register_datagram_context_capsule_.context_id, + ",format_type=", + DatagramFormatTypeToString( + register_datagram_context_capsule_.format_type), + "){", + absl::BytesToHexString( + register_datagram_context_capsule_.format_additional_data), + "}"); + break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + absl::StrAppend( + &rv, "(format_type=", + DatagramFormatTypeToString( + register_datagram_no_context_capsule_.format_type), + "){", + absl::BytesToHexString( + register_datagram_no_context_capsule_.format_additional_data), + "}"); + break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + absl::StrAppend( + &rv, "(context_id=", close_datagram_context_capsule_.context_id, + ",close_code=", + ContextCloseCodeToString(close_datagram_context_capsule_.close_code), + ",close_details=\"", + absl::BytesToHexString(close_datagram_context_capsule_.close_details), + "\")"); + break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + absl::StrAppend( + &rv, "(error_code=", close_web_transport_session_capsule_.error_code, + ",error_message=\"", + close_web_transport_session_capsule_.error_message, "\")"); + break; + default: + absl::StrAppend(&rv, "[", absl::BytesToHexString(unknown_capsule_data_), + "]"); + break; + } + return rv; +} + +std::ostream& operator<<(std::ostream& os, const Capsule& capsule) { + os << capsule.ToString(); + return os; +} + +CapsuleParser::CapsuleParser(Visitor* visitor) : visitor_(visitor) { + QUICHE_DCHECK_NE(visitor_, nullptr); +} + +QuicBuffer SerializeCapsule(const Capsule& capsule, + QuicBufferAllocator* allocator) { + QuicByteCount capsule_type_length = QuicDataWriter::GetVarInt62Len( + static_cast(capsule.capsule_type())); + QuicByteCount capsule_data_length; + switch (capsule.capsule_type()) { + case CapsuleType::LEGACY_DATAGRAM: + capsule_data_length = + capsule.legacy_datagram_capsule().http_datagram_payload.length(); + if (capsule.legacy_datagram_capsule().context_id.has_value()) { + capsule_data_length += QuicDataWriter::GetVarInt62Len( + capsule.legacy_datagram_capsule().context_id.value()); + } + break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + capsule_data_length = + QuicDataWriter::GetVarInt62Len( + capsule.datagram_with_context_capsule().context_id) + + capsule.datagram_with_context_capsule() + .http_datagram_payload.length(); + break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + capsule_data_length = capsule.datagram_without_context_capsule() + .http_datagram_payload.length(); + break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + capsule_data_length = + QuicDataWriter::GetVarInt62Len( + capsule.register_datagram_context_capsule().context_id) + + QuicDataWriter::GetVarInt62Len(static_cast( + capsule.register_datagram_context_capsule().format_type)) + + capsule.register_datagram_context_capsule() + .format_additional_data.length(); + break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + capsule_data_length = + QuicDataWriter::GetVarInt62Len(static_cast( + capsule.register_datagram_no_context_capsule().format_type)) + + capsule.register_datagram_no_context_capsule() + .format_additional_data.length(); + break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + capsule_data_length = + QuicDataWriter::GetVarInt62Len( + capsule.close_datagram_context_capsule().context_id) + + QuicDataWriter::GetVarInt62Len(static_cast( + capsule.close_datagram_context_capsule().close_code)) + + capsule.close_datagram_context_capsule().close_details.length(); + break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + capsule_data_length = + sizeof(WebTransportSessionError) + + capsule.close_web_transport_session_capsule().error_message.size(); + break; + default: + capsule_data_length = capsule.unknown_capsule_data().length(); + break; + } + QuicByteCount capsule_length_length = + QuicDataWriter::GetVarInt62Len(capsule_data_length); + QuicByteCount total_capsule_length = + capsule_type_length + capsule_length_length + capsule_data_length; + QuicBuffer buffer(allocator, total_capsule_length); + QuicDataWriter writer(buffer.size(), buffer.data()); + if (!writer.WriteVarInt62(static_cast(capsule.capsule_type()))) { + QUIC_BUG(capsule type write fail) << "Failed to write CAPSULE type"; + return QuicBuffer(); + } + if (!writer.WriteVarInt62(capsule_data_length)) { + QUIC_BUG(capsule length write fail) << "Failed to write CAPSULE length"; + return QuicBuffer(); + } + switch (capsule.capsule_type()) { + case CapsuleType::LEGACY_DATAGRAM: + if (capsule.legacy_datagram_capsule().context_id.has_value()) { + if (!writer.WriteVarInt62( + capsule.legacy_datagram_capsule().context_id.value())) { + QUIC_BUG(datagram capsule context ID write fail) + << "Failed to write LEGACY_DATAGRAM CAPSULE context ID"; + return QuicBuffer(); + } + } + if (!writer.WriteStringPiece( + capsule.legacy_datagram_capsule().http_datagram_payload)) { + QUIC_BUG(datagram capsule payload write fail) + << "Failed to write LEGACY_DATAGRAM CAPSULE payload"; + return QuicBuffer(); + } + break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + if (!writer.WriteVarInt62( + capsule.datagram_with_context_capsule().context_id)) { + QUIC_BUG(datagram capsule context ID write fail) + << "Failed to write DATAGRAM_WITH_CONTEXT CAPSULE context ID"; + return QuicBuffer(); + } + if (!writer.WriteStringPiece( + capsule.datagram_with_context_capsule().http_datagram_payload)) { + QUIC_BUG(datagram capsule payload write fail) + << "Failed to write DATAGRAM_WITH_CONTEXT CAPSULE payload"; + return QuicBuffer(); + } + break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + if (!writer.WriteStringPiece(capsule.datagram_without_context_capsule() + .http_datagram_payload)) { + QUIC_BUG(datagram capsule payload write fail) + << "Failed to write DATAGRAM_WITHOUT_CONTEXT CAPSULE payload"; + return QuicBuffer(); + } + break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + if (!writer.WriteVarInt62( + capsule.register_datagram_context_capsule().context_id)) { + QUIC_BUG(register context capsule context ID write fail) + << "Failed to write REGISTER_DATAGRAM_CONTEXT CAPSULE context ID"; + return QuicBuffer(); + } + if (!writer.WriteVarInt62(static_cast( + capsule.register_datagram_context_capsule().format_type))) { + QUIC_BUG(register context capsule format type write fail) + << "Failed to write REGISTER_DATAGRAM_CONTEXT CAPSULE format type"; + return QuicBuffer(); + } + if (!writer.WriteStringPiece(capsule.register_datagram_context_capsule() + .format_additional_data)) { + QUIC_BUG(register context capsule additional data write fail) + << "Failed to write REGISTER_DATAGRAM_CONTEXT CAPSULE additional " + "data"; + return QuicBuffer(); + } + break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + if (!writer.WriteVarInt62(static_cast( + capsule.register_datagram_no_context_capsule().format_type))) { + QUIC_BUG(register no context capsule format type write fail) + << "Failed to write REGISTER_DATAGRAM_NO_CONTEXT CAPSULE format " + "type"; + return QuicBuffer(); + } + if (!writer.WriteStringPiece( + capsule.register_datagram_no_context_capsule() + .format_additional_data)) { + QUIC_BUG(register no context capsule additional data write fail) + << "Failed to write REGISTER_DATAGRAM_NO_CONTEXT CAPSULE " + "additional data"; + return QuicBuffer(); + } + break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + if (!writer.WriteVarInt62( + capsule.close_datagram_context_capsule().context_id)) { + QUIC_BUG(close context capsule context ID write fail) + << "Failed to write CLOSE_DATAGRAM_CONTEXT CAPSULE context ID"; + return QuicBuffer(); + } + if (!writer.WriteVarInt62(static_cast( + capsule.close_datagram_context_capsule().close_code))) { + QUIC_BUG(close context capsule close code write fail) + << "Failed to write CLOSE_DATAGRAM_CONTEXT CAPSULE close code"; + return QuicBuffer(); + } + if (!writer.WriteStringPiece( + capsule.close_datagram_context_capsule().close_details)) { + QUIC_BUG(close context capsule close details write fail) + << "Failed to write CLOSE_DATAGRAM_CONTEXT CAPSULE close details"; + return QuicBuffer(); + } + break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + if (!writer.WriteUInt32( + capsule.close_web_transport_session_capsule().error_code)) { + QUIC_BUG(close webtransport session capsule error code write fail) + << "Failed to write CLOSE_WEBTRANSPORT_SESSION error code"; + return QuicBuffer(); + } + if (!writer.WriteStringPiece( + capsule.close_web_transport_session_capsule().error_message)) { + QUIC_BUG(close webtransport session capsule error message write fail) + << "Failed to write CLOSE_WEBTRANSPORT_SESSION error message"; + return QuicBuffer(); + } + break; + default: + if (!writer.WriteStringPiece(capsule.unknown_capsule_data())) { + QUIC_BUG(capsule data write fail) << "Failed to write CAPSULE data"; + return QuicBuffer(); + } + break; + } + if (writer.remaining() != 0) { + QUIC_BUG(capsule write length mismatch) + << "CAPSULE serialization wrote " << writer.length() << " instead of " + << writer.capacity(); + return QuicBuffer(); + } + return buffer; +} + +bool CapsuleParser::IngestCapsuleFragment(absl::string_view capsule_fragment) { + if (parsing_error_occurred_) { + return false; + } + absl::StrAppend(&buffered_data_, capsule_fragment); + while (true) { + const size_t buffered_data_read = AttemptParseCapsule(); + if (parsing_error_occurred_) { + QUICHE_DCHECK_EQ(buffered_data_read, 0u); + buffered_data_.clear(); + return false; + } + if (buffered_data_read == 0) { + break; + } + buffered_data_.erase(0, buffered_data_read); + } + static constexpr size_t kMaxCapsuleBufferSize = 1024 * 1024; + if (buffered_data_.size() > kMaxCapsuleBufferSize) { + buffered_data_.clear(); + ReportParseFailure("Refusing to buffer too much capsule data"); + return false; + } + return true; +} + +size_t CapsuleParser::AttemptParseCapsule() { + QUICHE_DCHECK(!parsing_error_occurred_); + if (buffered_data_.empty()) { + return 0; + } + QuicDataReader capsule_fragment_reader(buffered_data_); + uint64_t capsule_type64; + if (!capsule_fragment_reader.ReadVarInt62(&capsule_type64)) { + QUIC_DVLOG(2) << "Partial read: not enough data to read capsule type"; + return 0; + } + absl::string_view capsule_data; + if (!capsule_fragment_reader.ReadStringPieceVarInt62(&capsule_data)) { + QUIC_DVLOG(2) << "Partial read: not enough data to read capsule length or " + "full capsule data"; + return 0; + } + QuicDataReader capsule_data_reader(capsule_data); + Capsule capsule(static_cast(capsule_type64)); + switch (capsule.capsule_type()) { + case CapsuleType::LEGACY_DATAGRAM: + if (datagram_context_id_present_) { + uint64_t context_id; + if (!capsule_data_reader.ReadVarInt62(&context_id)) { + ReportParseFailure( + "Unable to parse capsule LEGACY_DATAGRAM context ID"); + return 0; + } + capsule.legacy_datagram_capsule().context_id = context_id; + } + capsule.legacy_datagram_capsule().http_datagram_payload = + capsule_data_reader.ReadRemainingPayload(); + break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: + uint64_t context_id; + if (!capsule_data_reader.ReadVarInt62(&context_id)) { + ReportParseFailure( + "Unable to parse capsule DATAGRAM_WITH_CONTEXT context ID"); + return 0; + } + capsule.datagram_with_context_capsule().context_id = context_id; + capsule.datagram_with_context_capsule().http_datagram_payload = + capsule_data_reader.ReadRemainingPayload(); + break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: + capsule.datagram_without_context_capsule().http_datagram_payload = + capsule_data_reader.ReadRemainingPayload(); + break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: + if (!capsule_data_reader.ReadVarInt62( + &capsule.register_datagram_context_capsule().context_id)) { + ReportParseFailure( + "Unable to parse capsule REGISTER_DATAGRAM_CONTEXT context ID"); + return 0; + } + if (!capsule_data_reader.ReadVarInt62(reinterpret_cast( + &capsule.register_datagram_context_capsule().format_type))) { + ReportParseFailure( + "Unable to parse capsule REGISTER_DATAGRAM_CONTEXT format type"); + return 0; + } + capsule.register_datagram_context_capsule().format_additional_data = + capsule_data_reader.ReadRemainingPayload(); + break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: + if (!capsule_data_reader.ReadVarInt62(reinterpret_cast( + &capsule.register_datagram_no_context_capsule().format_type))) { + ReportParseFailure( + "Unable to parse capsule REGISTER_DATAGRAM_NO_CONTEXT format type"); + return 0; + } + capsule.register_datagram_no_context_capsule().format_additional_data = + capsule_data_reader.ReadRemainingPayload(); + break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: + if (!capsule_data_reader.ReadVarInt62( + &capsule.close_datagram_context_capsule().context_id)) { + ReportParseFailure( + "Unable to parse capsule CLOSE_DATAGRAM_CONTEXT context ID"); + return 0; + } + if (!capsule_data_reader.ReadVarInt62(reinterpret_cast( + &capsule.close_datagram_context_capsule().close_code))) { + ReportParseFailure( + "Unable to parse capsule CLOSE_DATAGRAM_CONTEXT close code"); + return 0; + } + capsule.close_datagram_context_capsule().close_details = + capsule_data_reader.ReadRemainingPayload(); + break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: + if (!capsule_data_reader.ReadUInt32( + &capsule.close_web_transport_session_capsule().error_code)) { + ReportParseFailure( + "Unable to parse capsule CLOSE_WEBTRANSPORT_SESSION error code"); + return 0; + } + capsule.close_web_transport_session_capsule().error_message = + capsule_data_reader.ReadRemainingPayload(); + break; + default: + capsule.unknown_capsule_data() = + capsule_data_reader.ReadRemainingPayload(); + } + if (!visitor_->OnCapsule(capsule)) { + ReportParseFailure("Visitor failed to process capsule"); + return 0; + } + return capsule_fragment_reader.PreviouslyReadPayload().length(); +} + +void CapsuleParser::ReportParseFailure(const std::string& error_message) { + if (parsing_error_occurred_) { + QUIC_BUG(multiple parse errors) << "Experienced multiple parse failures"; + return; + } + parsing_error_occurred_ = true; + visitor_->OnCapsuleParseFailure(error_message); +} + +void CapsuleParser::ErrorIfThereIsRemainingBufferedData() { + if (parsing_error_occurred_) { + return; + } + if (!buffered_data_.empty()) { + ReportParseFailure("Incomplete capsule left at the end of the stream"); + } +} + +} // namespace quic diff --git a/gquiche/quic/core/http/capsule.h b/gquiche/quic/core/http/capsule.h new file mode 100644 index 00000000..bce5c37d --- /dev/null +++ b/gquiche/quic/core/http/capsule.h @@ -0,0 +1,291 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_HTTP_CAPSULE_H_ +#define QUICHE_QUIC_CORE_HTTP_CAPSULE_H_ + +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/quic/core/quic_buffer_allocator.h" +#include "gquiche/quic/core/quic_data_reader.h" +#include "gquiche/quic/core/quic_types.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace quic { + +enum class CapsuleType : uint64_t { + // Casing in this enum matches the IETF specification. + LEGACY_DATAGRAM = 0xff37a0, // draft-ietf-masque-h3-datagram-04 + REGISTER_DATAGRAM_CONTEXT = 0xff37a1, + REGISTER_DATAGRAM_NO_CONTEXT = 0xff37a2, + CLOSE_DATAGRAM_CONTEXT = 0xff37a3, + DATAGRAM_WITH_CONTEXT = 0xff37a4, + DATAGRAM_WITHOUT_CONTEXT = 0xff37a5, + CLOSE_WEBTRANSPORT_SESSION = 0x2843, +}; + +QUIC_EXPORT_PRIVATE std::string CapsuleTypeToString(CapsuleType capsule_type); +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const CapsuleType& capsule_type); + +enum class DatagramFormatType : uint64_t { + // Casing in this enum matches the IETF specification. + UDP_PAYLOAD = 0xff6f00, + WEBTRANSPORT = 0xff7c00, +}; + +QUIC_EXPORT_PRIVATE std::string DatagramFormatTypeToString( + DatagramFormatType datagram_format_type); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const DatagramFormatType& datagram_format_type); + +enum class ContextCloseCode : uint64_t { + // Casing in this enum matches the IETF specification. + CLOSE_NO_ERROR = 0xff78a0, // NO_ERROR already exists in winerror.h. + UNKNOWN_FORMAT = 0xff78a1, + DENIED = 0xff78a2, + RESOURCE_LIMIT = 0xff78a3, +}; + +QUIC_EXPORT_PRIVATE std::string ContextCloseCodeToString( + ContextCloseCode context_close_code); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ContextCloseCode& context_close_code); + +struct QUIC_EXPORT_PRIVATE LegacyDatagramCapsule { + absl::optional context_id; + absl::string_view http_datagram_payload; +}; +struct QUIC_EXPORT_PRIVATE DatagramWithContextCapsule { + QuicDatagramContextId context_id; + absl::string_view http_datagram_payload; +}; +struct QUIC_EXPORT_PRIVATE DatagramWithoutContextCapsule { + absl::string_view http_datagram_payload; +}; +struct QUIC_EXPORT_PRIVATE RegisterDatagramContextCapsule { + QuicDatagramContextId context_id; + DatagramFormatType format_type; + absl::string_view format_additional_data; +}; +struct QUIC_EXPORT_PRIVATE RegisterDatagramNoContextCapsule { + DatagramFormatType format_type; + absl::string_view format_additional_data; +}; +struct QUIC_EXPORT_PRIVATE CloseDatagramContextCapsule { + QuicDatagramContextId context_id; + ContextCloseCode close_code; + absl::string_view close_details; +}; +struct QUIC_EXPORT_PRIVATE CloseWebTransportSessionCapsule { + WebTransportSessionError error_code; + absl::string_view error_message; +}; + +// Capsule from draft-ietf-masque-h3-datagram. +// IMPORTANT NOTE: Capsule does not own any of the absl::string_view memory it +// points to. Strings saved into a capsule must outlive the capsule object. Any +// code that sees a capsule in a callback needs to either process it immediately +// or perform its own deep copy. +class QUIC_EXPORT_PRIVATE Capsule { + public: + static Capsule LegacyDatagram( + absl::optional context_id = absl::nullopt, + absl::string_view http_datagram_payload = absl::string_view()); + static Capsule DatagramWithContext( + QuicDatagramContextId context_id, + absl::string_view http_datagram_payload = absl::string_view()); + static Capsule DatagramWithoutContext( + absl::string_view http_datagram_payload = absl::string_view()); + static Capsule RegisterDatagramContext( + QuicDatagramContextId context_id, DatagramFormatType format_type, + absl::string_view format_additional_data = absl::string_view()); + static Capsule RegisterDatagramNoContext( + DatagramFormatType format_type, + absl::string_view format_additional_data = absl::string_view()); + static Capsule CloseDatagramContext( + QuicDatagramContextId context_id, + ContextCloseCode close_code = ContextCloseCode::CLOSE_NO_ERROR, + absl::string_view close_details = absl::string_view()); + static Capsule CloseWebTransportSession( + WebTransportSessionError error_code = 0, + absl::string_view error_message = ""); + static Capsule Unknown( + uint64_t capsule_type, + absl::string_view unknown_capsule_data = absl::string_view()); + + explicit Capsule(CapsuleType capsule_type); + Capsule(const Capsule& other); + Capsule& operator=(const Capsule& other); + bool operator==(const Capsule& other) const; + + // Human-readable information string for debugging purposes. + std::string ToString() const; + friend QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + const Capsule& capsule); + + CapsuleType capsule_type() const { return capsule_type_; } + LegacyDatagramCapsule& legacy_datagram_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::LEGACY_DATAGRAM); + return legacy_datagram_capsule_; + } + const LegacyDatagramCapsule& legacy_datagram_capsule() const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::LEGACY_DATAGRAM); + return legacy_datagram_capsule_; + } + DatagramWithContextCapsule& datagram_with_context_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::DATAGRAM_WITH_CONTEXT); + return datagram_with_context_capsule_; + } + const DatagramWithContextCapsule& datagram_with_context_capsule() const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::DATAGRAM_WITH_CONTEXT); + return datagram_with_context_capsule_; + } + DatagramWithoutContextCapsule& datagram_without_context_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::DATAGRAM_WITHOUT_CONTEXT); + return datagram_without_context_capsule_; + } + const DatagramWithoutContextCapsule& datagram_without_context_capsule() + const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::DATAGRAM_WITHOUT_CONTEXT); + return datagram_without_context_capsule_; + } + RegisterDatagramContextCapsule& register_datagram_context_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::REGISTER_DATAGRAM_CONTEXT); + return register_datagram_context_capsule_; + } + const RegisterDatagramContextCapsule& register_datagram_context_capsule() + const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::REGISTER_DATAGRAM_CONTEXT); + return register_datagram_context_capsule_; + } + RegisterDatagramNoContextCapsule& register_datagram_no_context_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT); + return register_datagram_no_context_capsule_; + } + const RegisterDatagramNoContextCapsule& register_datagram_no_context_capsule() + const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT); + return register_datagram_no_context_capsule_; + } + CloseDatagramContextCapsule& close_datagram_context_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::CLOSE_DATAGRAM_CONTEXT); + return close_datagram_context_capsule_; + } + const CloseDatagramContextCapsule& close_datagram_context_capsule() const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::CLOSE_DATAGRAM_CONTEXT); + return close_datagram_context_capsule_; + } + CloseWebTransportSessionCapsule& close_web_transport_session_capsule() { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::CLOSE_WEBTRANSPORT_SESSION); + return close_web_transport_session_capsule_; + } + const CloseWebTransportSessionCapsule& close_web_transport_session_capsule() + const { + QUICHE_DCHECK_EQ(capsule_type_, CapsuleType::CLOSE_WEBTRANSPORT_SESSION); + return close_web_transport_session_capsule_; + } + absl::string_view& unknown_capsule_data() { + QUICHE_DCHECK(capsule_type_ != CapsuleType::LEGACY_DATAGRAM && + capsule_type_ != CapsuleType::DATAGRAM_WITH_CONTEXT && + capsule_type_ != CapsuleType::DATAGRAM_WITHOUT_CONTEXT && + capsule_type_ != CapsuleType::REGISTER_DATAGRAM_CONTEXT && + capsule_type_ != CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT && + capsule_type_ != CapsuleType::CLOSE_DATAGRAM_CONTEXT && + capsule_type_ != CapsuleType::CLOSE_WEBTRANSPORT_SESSION) + << capsule_type_; + return unknown_capsule_data_; + } + const absl::string_view& unknown_capsule_data() const { + QUICHE_DCHECK(capsule_type_ != CapsuleType::LEGACY_DATAGRAM && + capsule_type_ != CapsuleType::DATAGRAM_WITH_CONTEXT && + capsule_type_ != CapsuleType::DATAGRAM_WITHOUT_CONTEXT && + capsule_type_ != CapsuleType::REGISTER_DATAGRAM_CONTEXT && + capsule_type_ != CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT && + capsule_type_ != CapsuleType::CLOSE_DATAGRAM_CONTEXT && + capsule_type_ != CapsuleType::CLOSE_WEBTRANSPORT_SESSION) + << capsule_type_; + return unknown_capsule_data_; + } + + private: + CapsuleType capsule_type_; + union { + LegacyDatagramCapsule legacy_datagram_capsule_; + DatagramWithContextCapsule datagram_with_context_capsule_; + DatagramWithoutContextCapsule datagram_without_context_capsule_; + RegisterDatagramContextCapsule register_datagram_context_capsule_; + RegisterDatagramNoContextCapsule register_datagram_no_context_capsule_; + CloseDatagramContextCapsule close_datagram_context_capsule_; + CloseWebTransportSessionCapsule close_web_transport_session_capsule_; + absl::string_view unknown_capsule_data_; + }; +}; + +namespace test { +class CapsuleParserPeer; +} // namespace test + +class QUIC_EXPORT_PRIVATE CapsuleParser { + public: + class QUIC_EXPORT_PRIVATE Visitor { + public: + virtual ~Visitor() {} + + // Called when a capsule has been successfully parsed. The return value + // indicates whether the contents of the capsule are valid: if false is + // returned, the parse operation will be considered failed and + // OnCapsuleParseFailure will be called. Note that since Capsule does not + // own the memory backing its string_views, that memory is only valid until + // this callback returns. Visitors that wish to access the capsule later + // MUST make a deep copy before this returns. + virtual bool OnCapsule(const Capsule& capsule) = 0; + + virtual void OnCapsuleParseFailure(const std::string& error_message) = 0; + }; + + // |visitor| must be non-null, and must outlive CapsuleParser. + explicit CapsuleParser(Visitor* visitor); + + void set_datagram_context_id_present(bool datagram_context_id_present) { + datagram_context_id_present_ = datagram_context_id_present; + } + + // Ingests a capsule fragment (any fragment of bytes from the capsule data + // stream) and parses and complete capsules it encounters. Returns false if a + // parsing error occurred. + bool IngestCapsuleFragment(absl::string_view capsule_fragment); + + void ErrorIfThereIsRemainingBufferedData(); + + friend class test::CapsuleParserPeer; + + private: + // Attempts to parse a single capsule from |buffered_data_|. If a full capsule + // is not available, returns 0. If a parsing error occurs, returns 0. + // Otherwise, returns the number of bytes in the parsed capsule. + size_t AttemptParseCapsule(); + void ReportParseFailure(const std::string& error_message); + + // Whether HTTP Datagram Context IDs are present. + bool datagram_context_id_present_ = false; + // Whether a parsing error has occurred. + bool parsing_error_occurred_ = false; + // Visitor which will receive callbacks, unowned. + Visitor* visitor_; + + std::string buffered_data_; +}; + +// Serializes |capsule| into a newly allocated buffer. +QUIC_EXPORT_PRIVATE QuicBuffer SerializeCapsule(const Capsule& capsule, + QuicBufferAllocator* allocator); + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_HTTP_CAPSULE_H_ diff --git a/gquiche/quic/core/http/capsule_test.cc b/gquiche/quic/core/http/capsule_test.cc new file mode 100644 index 00000000..3b74746e --- /dev/null +++ b/gquiche/quic/core/http/capsule_test.cc @@ -0,0 +1,364 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/http/capsule.h" + +#include +#include +#include +#include + +#include "absl/strings/escaping.h" +#include "absl/strings/string_view.h" +#include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/quic/test_tools/quic_test_utils.h" +#include "gquiche/common/test_tools/quiche_test_utils.h" + +using ::testing::_; +using ::testing::InSequence; +using ::testing::Return; + +namespace quic { +namespace test { + +class CapsuleParserPeer { + public: + static std::string* buffered_data(CapsuleParser* capsule_parser) { + return &capsule_parser->buffered_data_; + } +}; + +namespace { + +constexpr DatagramFormatType kFakeFormatType = + static_cast(0x123456); +constexpr ContextCloseCode kFakeCloseCode = + static_cast(0x654321); + +class MockCapsuleParserVisitor : public CapsuleParser::Visitor { + public: + MockCapsuleParserVisitor() { + ON_CALL(*this, OnCapsule(_)).WillByDefault(Return(true)); + } + ~MockCapsuleParserVisitor() override = default; + MOCK_METHOD(bool, OnCapsule, (const Capsule& capsule), (override)); + MOCK_METHOD(void, OnCapsuleParseFailure, (const std::string& error_message), + (override)); +}; + +class CapsuleTest : public QuicTest { + public: + CapsuleTest() : capsule_parser_(&visitor_) {} + + void ValidateParserIsEmpty() { + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + EXPECT_CALL(visitor_, OnCapsuleParseFailure(_)).Times(0); + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); + } + + void TestSerialization(const Capsule& capsule, + const std::string& expected_bytes) { + QuicBuffer serialized_capsule = + SerializeCapsule(capsule, SimpleBufferAllocator::Get()); + quiche::test::CompareCharArraysWithHexError( + "Serialized capsule", serialized_capsule.data(), + serialized_capsule.size(), expected_bytes.data(), + expected_bytes.size()); + } + + ::testing::StrictMock visitor_; + CapsuleParser capsule_parser_; +}; + +TEST_F(CapsuleTest, LegacyDatagramCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a0" // LEGACY_DATAGRAM capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = + Capsule::LegacyDatagram(/*context_id=*/absl::nullopt, datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, LegacyDatagramCapsuleWithContext) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a0" // LEGACY_DATAGRAM capsule type + "09" // capsule length + "04" // context ID + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + capsule_parser_.set_datagram_context_id_present(true); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = + Capsule::LegacyDatagram(/*context_id=*/4, datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, DatagramWithoutContextCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a5" // DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = Capsule::DatagramWithoutContext(datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, DatagramWithContextCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a4" // DATAGRAM_WITH_CONTEXT capsule type + "09" // capsule length + "04" // context ID + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + ); + std::string datagram_payload = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = + Capsule::DatagramWithContext(/*context_id=*/4, datagram_payload); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, RegisterContextCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a1" // REGISTER_DATAGRAM_CONTEXT capsule type + "0d" // capsule length + "04" // context ID + "80123456" // 0x123456 datagram format type + "f1f2f3f4f5f6f7f8" // format additional data + ); + std::string format_additional_data = + absl::HexStringToBytes("f1f2f3f4f5f6f7f8"); + Capsule expected_capsule = Capsule::RegisterDatagramContext( + /*context_id=*/4, kFakeFormatType, format_additional_data); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, RegisterNoContextCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a2" // REGISTER_DATAGRAM_NO_CONTEXT capsule type + "0c" // capsule length + "80123456" // 0x123456 datagram format type + "f1f2f3f4f5f6f7f8" // format additional data + ); + std::string format_additional_data = + absl::HexStringToBytes("f1f2f3f4f5f6f7f8"); + Capsule expected_capsule = Capsule::RegisterDatagramNoContext( + kFakeFormatType, format_additional_data); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, CloseContextCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a3" // CLOSE_DATAGRAM_CONTEXT capsule type + "27" // capsule length + "04" // context ID + "80654321" // 0x654321 close code + ); + std::string close_details = "All your contexts are belong to us"; + capsule_fragment += close_details; + Capsule expected_capsule = Capsule::CloseDatagramContext( + /*context_id=*/4, kFakeCloseCode, close_details); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, CloseWebTransportStreamCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "6843" // CLOSE_WEBTRANSPORT_STREAM capsule type + "09" // capsule length + "00001234" // 0x1234 error code + "68656c6c6f" // "hello" error message + ); + Capsule expected_capsule = Capsule::CloseWebTransportSession( + /*error_code=*/0x1234, /*error_message=*/"hello"); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, UnknownCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "33" // unknown capsule type of 0x33 + "08" // capsule length + "a1a2a3a4a5a6a7a8" // unknown capsule data + ); + std::string unknown_capsule_data = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + Capsule expected_capsule = Capsule::Unknown(0x33, unknown_capsule_data); + { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); + TestSerialization(expected_capsule, capsule_fragment); +} + +TEST_F(CapsuleTest, TwoCapsules) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a5" // DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + "80ff37a5" // DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "b1b2b3b4b5b6b7b8" // HTTP Datagram payload + ); + std::string datagram_payload1 = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + std::string datagram_payload2 = absl::HexStringToBytes("b1b2b3b4b5b6b7b8"); + Capsule expected_capsule1 = + Capsule::DatagramWithoutContext(datagram_payload1); + Capsule expected_capsule2 = + Capsule::DatagramWithoutContext(datagram_payload2); + { + InSequence s; + EXPECT_CALL(visitor_, OnCapsule(expected_capsule1)); + EXPECT_CALL(visitor_, OnCapsule(expected_capsule2)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + ValidateParserIsEmpty(); +} + +TEST_F(CapsuleTest, TwoCapsulesPartialReads) { + std::string capsule_fragment1 = absl::HexStringToBytes( + "80ff37a5" // first capsule DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // frist capsule length + "a1a2a3a4" // first half of HTTP Datagram payload of first capsule + ); + std::string capsule_fragment2 = absl::HexStringToBytes( + "a5a6a7a8" // second half of HTTP Datagram payload 1 + "80ff37a5" // second capsule DATAGRAM_WITHOUT_CONTEXT capsule type + ); + std::string capsule_fragment3 = absl::HexStringToBytes( + "08" // second capsule length + "b1b2b3b4b5b6b7b8" // HTTP Datagram payload of second capsule + ); + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + std::string datagram_payload1 = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + std::string datagram_payload2 = absl::HexStringToBytes("b1b2b3b4b5b6b7b8"); + Capsule expected_capsule1 = + Capsule::DatagramWithoutContext(datagram_payload1); + Capsule expected_capsule2 = + Capsule::DatagramWithoutContext(datagram_payload2); + { + InSequence s; + EXPECT_CALL(visitor_, OnCapsule(expected_capsule1)); + EXPECT_CALL(visitor_, OnCapsule(expected_capsule2)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment1)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment2)); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment3)); + } + ValidateParserIsEmpty(); +} + +TEST_F(CapsuleTest, TwoCapsulesOneByteAtATime) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a5" // DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "a1a2a3a4a5a6a7a8" // HTTP Datagram payload + "80ff37a5" // DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "b1b2b3b4b5b6b7b8" // HTTP Datagram payload + ); + std::string datagram_payload1 = absl::HexStringToBytes("a1a2a3a4a5a6a7a8"); + std::string datagram_payload2 = absl::HexStringToBytes("b1b2b3b4b5b6b7b8"); + Capsule expected_capsule1 = + Capsule::DatagramWithoutContext(datagram_payload1); + Capsule expected_capsule2 = + Capsule::DatagramWithoutContext(datagram_payload2); + for (size_t i = 0; i < capsule_fragment.size(); i++) { + if (i < capsule_fragment.size() / 2 - 1) { + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + } else if (i == capsule_fragment.size() / 2 - 1) { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule1)); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); + } else if (i < capsule_fragment.size() - 1) { + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + } else { + EXPECT_CALL(visitor_, OnCapsule(expected_capsule2)); + ASSERT_TRUE( + capsule_parser_.IngestCapsuleFragment(capsule_fragment.substr(i, 1))); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); + } + } + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + EXPECT_TRUE(CapsuleParserPeer::buffered_data(&capsule_parser_)->empty()); +} + +TEST_F(CapsuleTest, PartialCapsuleThenError) { + std::string capsule_fragment = absl::HexStringToBytes( + "80ff37a5" // DATAGRAM_WITHOUT_CONTEXT capsule type + "08" // capsule length + "a1a2a3a4" // first half of HTTP Datagram payload + ); + EXPECT_CALL(visitor_, OnCapsule(_)).Times(0); + { + EXPECT_CALL(visitor_, OnCapsuleParseFailure(_)).Times(0); + ASSERT_TRUE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); + } + { + EXPECT_CALL(visitor_, + OnCapsuleParseFailure( + "Incomplete capsule left at the end of the stream")); + capsule_parser_.ErrorIfThereIsRemainingBufferedData(); + } +} + +TEST_F(CapsuleTest, RejectOverlyLongCapsule) { + std::string capsule_fragment = absl::HexStringToBytes( + "33" // unknown capsule type of 0x33 + "80123456" // capsule length + ) + + std::string(1111111, '?'); + EXPECT_CALL(visitor_, OnCapsuleParseFailure( + "Refusing to buffer too much capsule data")); + EXPECT_FALSE(capsule_parser_.IngestCapsuleFragment(capsule_fragment)); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/gquiche/quic/core/http/end_to_end_test.cc b/gquiche/quic/core/http/end_to_end_test.cc index 50804394..2966277d 100644 --- a/gquiche/quic/core/http/end_to_end_test.cc +++ b/gquiche/quic/core/http/end_to_end_test.cc @@ -14,9 +14,11 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/crypto/null_encrypter.h" +#include "gquiche/quic/core/crypto/quic_client_session_cache.h" #include "gquiche/quic/core/http/http_constants.h" #include "gquiche/quic/core/http/quic_spdy_client_stream.h" #include "gquiche/quic/core/http/web_transport_http3.h" +#include "gquiche/quic/core/quic_connection.h" #include "gquiche/quic/core/quic_data_writer.h" #include "gquiche/quic/core/quic_epoll_connection_helper.h" #include "gquiche/quic/core/quic_error_codes.h" @@ -45,6 +47,7 @@ #include "gquiche/quic/test_tools/qpack/qpack_encoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" #include "gquiche/quic/test_tools/quic_client_peer.h" +#include "gquiche/quic/test_tools/quic_client_session_cache_peer.h" #include "gquiche/quic/test_tools/quic_config_peer.h" #include "gquiche/quic/test_tools/quic_connection_peer.h" #include "gquiche/quic/test_tools/quic_dispatcher_peer.h" @@ -53,6 +56,7 @@ #include "gquiche/quic/test_tools/quic_server_peer.h" #include "gquiche/quic/test_tools/quic_session_peer.h" #include "gquiche/quic/test_tools/quic_spdy_session_peer.h" +#include "gquiche/quic/test_tools/quic_spdy_stream_peer.h" #include "gquiche/quic/test_tools/quic_stream_id_manager_peer.h" #include "gquiche/quic/test_tools/quic_stream_peer.h" #include "gquiche/quic/test_tools/quic_stream_sequencer_peer.h" @@ -62,7 +66,6 @@ #include "gquiche/quic/test_tools/quic_test_utils.h" #include "gquiche/quic/test_tools/quic_transport_test_tools.h" #include "gquiche/quic/test_tools/server_thread.h" -#include "gquiche/quic/test_tools/simple_session_cache.h" #include "gquiche/quic/tools/quic_backend_response.h" #include "gquiche/quic/tools/quic_client.h" #include "gquiche/quic/tools/quic_memory_cache_backend.h" @@ -79,7 +82,7 @@ using ::testing::_; using ::testing::Assign; using ::testing::Invoke; using ::testing::NiceMock; -using testing::NotNull; +using ::testing::UnorderedElementsAreArray; namespace quic { namespace test { @@ -135,7 +138,10 @@ void WriteHeadersOnStream(QuicSpdyStream* stream) { // Since QuicSpdyStream uses QuicHeaderList::empty() to detect too large // headers, it also fails when receiving empty headers. SpdyHeaderBlock headers; - headers["foo"] = "bar"; + headers[":authority"] = "test.example.com:443"; + headers[":path"] = "/path"; + headers[":method"] = "GET"; + headers[":scheme"] = "https"; stream->WriteHeaders(std::move(headers), /* fin = */ false, nullptr); } @@ -215,7 +221,7 @@ class EndToEndTest : public QuicTestWithParam { new QuicTestClient(server_address_, server_hostname_, client_config_, client_supported_versions_, crypto_test_utils::ProofVerifierForTesting(), - std::make_unique()); + std::make_unique()); client->SetUserAgentID(kTestUserAgentId); client->UseWriter(writer); if (!pre_shared_key_client_.empty()) { @@ -225,6 +231,7 @@ class EndToEndTest : public QuicTestWithParam { client->UseClientConnectionIdLength(override_client_connection_id_length_); client->client()->set_connection_debug_visitor(connection_debug_visitor_); client->client()->set_enable_web_transport(enable_web_transport_); + client->client()->set_use_datagram_contexts(use_datagram_contexts_); client->Connect(); return client; } @@ -362,9 +369,11 @@ class EndToEndTest : public QuicTestWithParam { bool Initialize() { if (enable_web_transport_) { - SetQuicReloadableFlag(quic_h3_datagram, true); memory_cache_backend_.set_enable_webtransport(true); } + if (use_datagram_contexts_) { + memory_cache_backend_.set_use_datagram_contexts(true); + } QuicTagVector copt; server_config_.SetConnectionOptionsToSend(copt); @@ -378,7 +387,10 @@ class EndToEndTest : public QuicTestWithParam { copt.push_back(kILD0); } copt.push_back(kPLE1); - copt.push_back(kRVCM); + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + copt.push_back(kRVCM); + } client_config_.SetConnectionOptionsToSend(copt); // Start the server first, because CreateQuicClient() attempts @@ -541,8 +553,12 @@ class EndToEndTest : public QuicTestWithParam { EXPECT_EQ(0u, server_stats.packets_lost); } EXPECT_EQ(0u, server_stats.packets_discarded); - EXPECT_EQ(server_session->user_agent_id().value_or("MissingUserAgent"), - kTestUserAgentId); + if (!GetQuicReloadableFlag( + quic_ignore_user_agent_transport_parameter)) { + EXPECT_EQ( + server_session->user_agent_id().value_or("MissingUserAgent"), + kTestUserAgentId); + } } else { ADD_FAILURE() << "Missing server connection"; } @@ -679,8 +695,9 @@ class EndToEndTest : public QuicTestWithParam { return WaitForFooResponseAndCheckIt(client_.get()); } - WebTransportHttp3* CreateWebTransportSession(const std::string& path, - bool wait_for_server_response) { + WebTransportHttp3* CreateWebTransportSession( + const std::string& path, bool wait_for_server_response, + QuicSpdyStream** connect_stream_out = nullptr) { // Wait until we receive the settings from the server indicating // WebTransport support. client_->WaitUntil( @@ -712,6 +729,9 @@ class EndToEndTest : public QuicTestWithParam { [stream]() { return stream->headers_decompressed(); }); EXPECT_TRUE(session->ready()); } + if (connect_stream_out != nullptr) { + *connect_stream_out = stream; + } return session; } @@ -724,21 +744,37 @@ class EndToEndTest : public QuicTestWithParam { } std::string ReadDataFromWebTransportStreamUntilFin( - WebTransportStream* stream) { + WebTransportStream* stream, MockStreamVisitor* visitor = nullptr) { + QuicStreamId id = stream->GetStreamId(); std::string buffer; + + // Try reading data if immediately available. + WebTransportStream::ReadResult result = stream->Read(&buffer); + if (result.fin) { + return buffer; + } + while (true) { bool can_read = false; - auto visitor = std::make_unique(); + if (visitor == nullptr) { + auto visitor_owned = std::make_unique(); + visitor = visitor_owned.get(); + stream->SetVisitor(std::move(visitor_owned)); + } EXPECT_CALL(*visitor, OnCanRead()).WillOnce(Assign(&can_read, true)); - stream->SetVisitor(std::move(visitor)); client_->WaitUntil(5000 /*ms*/, [&can_read]() { return can_read; }); if (!can_read) { - ADD_FAILURE() << "Waiting for readable data on stream " - << stream->GetStreamId() << " timed out"; + ADD_FAILURE() << "Waiting for readable data on stream " << id + << " timed out"; + return buffer; + } + if (GetClientSession()->GetOrCreateSpdyDataStream(id) == nullptr) { + ADD_FAILURE() << "Stream " << id + << " was deleted while waiting for incoming data"; return buffer; } - WebTransportStream::ReadResult result = stream->Read(&buffer); + result = stream->Read(&buffer); if (result.fin) { return buffer; } @@ -750,6 +786,31 @@ class EndToEndTest : public QuicTestWithParam { } } + void ReadAllIncomingWebTransportUnidirectionalStreams( + WebTransportSession* session) { + while (true) { + WebTransportStream* received_stream = + session->AcceptIncomingUnidirectionalStream(); + if (received_stream == nullptr) { + break; + } + received_webtransport_unidirectional_streams_.push_back( + ReadDataFromWebTransportStreamUntilFin(received_stream)); + } + } + + void WaitForNewConnectionIds() { + // Wait until a new server CID is available for another migration. + const auto* client_connection = GetClientConnection(); + while (!QuicConnectionPeer::HasUnusedPeerIssuedConnectionId( + client_connection) || + (!client_connection->client_connection_id().IsEmpty() && + !QuicConnectionPeer::HasSelfIssuedConnectionIdToConsume( + client_connection))) { + client_->client()->WaitForEvents(); + } + } + ScopedEnvironmentForThreads environment_; bool initialized_; // If true, the Initialize() function will create |client_| and starts to @@ -778,6 +839,8 @@ class EndToEndTest : public QuicTestWithParam { int override_client_connection_id_length_ = -1; uint8_t expected_server_connection_id_length_; bool enable_web_transport_ = false; + bool use_datagram_contexts_ = false; + std::vector received_webtransport_unidirectional_streams_; }; // Run all end to end tests with all supported versions. @@ -787,6 +850,8 @@ INSTANTIATE_TEST_SUITE_P(EndToEndTests, ::testing::PrintToStringParamName()); TEST_P(EndToEndTest, HandshakeSuccessful) { + SetQuicReloadableFlag(quic_delay_sequencer_buffer_allocation_until_new_data, + true); ASSERT_TRUE(Initialize()); EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); ASSERT_TRUE(server_thread_); @@ -842,6 +907,51 @@ TEST_P(EndToEndTest, HandshakeSuccessful) { server_thread_->Resume(); } +TEST_P(EndToEndTest, ExportKeyingMaterial) { + ASSERT_TRUE(Initialize()); + if (!version_.UsesTls()) { + return; + } + const char* kExportLabel = "label"; + const int kExportLen = 30; + std::string client_keying_material_export, server_keying_material_export; + + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(server_thread_); + server_thread_->WaitForCryptoHandshakeConfirmed(); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + QuicCryptoStream* server_crypto_stream = nullptr; + if (server_session != nullptr) { + server_crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(server_session); + } else { + ADD_FAILURE() << "Missing server session"; + } + if (server_crypto_stream != nullptr) { + ASSERT_TRUE(server_crypto_stream->ExportKeyingMaterial( + kExportLabel, /*context=*/"", kExportLen, + &server_keying_material_export)); + + } else { + ADD_FAILURE() << "Missing server crypto stream"; + } + server_thread_->Resume(); + + QuicSpdyClientSession* client_session = GetClientSession(); + ASSERT_TRUE(client_session); + QuicCryptoStream* client_crypto_stream = + QuicSessionPeer::GetMutableCryptoStream(client_session); + ASSERT_TRUE(client_crypto_stream); + ASSERT_TRUE(client_crypto_stream->ExportKeyingMaterial( + kExportLabel, /*context=*/"", kExportLen, + &client_keying_material_export)); + ASSERT_EQ(client_keying_material_export.size(), + static_cast(kExportLen)); + EXPECT_EQ(client_keying_material_export, server_keying_material_export); +} + TEST_P(EndToEndTest, SimpleRequestResponse) { ASSERT_TRUE(Initialize()); @@ -1523,6 +1633,7 @@ TEST_P(EndToEndTest, LargePostNoPacketLossWithDelayAndReordering) { } TEST_P(EndToEndTest, AddressToken) { + client_extra_copts_.push_back(kTRTT); ASSERT_TRUE(Initialize()); if (!version_.HasIetfQuicFrames()) { return; @@ -1540,7 +1651,7 @@ TEST_P(EndToEndTest, AddressToken) { // The 0-RTT handshake should succeed. client_->Connect(); - EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_TRUE(client_->client()->WaitForHandshakeConfirmed()); ASSERT_TRUE(client_->client()->connected()); SendSynchronousFooRequestAndCheckResponse(); @@ -1550,26 +1661,217 @@ TEST_P(EndToEndTest, AddressToken) { EXPECT_TRUE(client_->client()->EarlyDataAccepted()); server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); QuicConnection* server_connection = GetServerConnection(); - if (server_connection != nullptr) { - if (GetQuicReloadableFlag(quic_enable_token_based_address_validation)) { + if (server_session != nullptr && server_connection != nullptr) { + // Verify address is validated via validating token received in INITIAL + // packet. + EXPECT_FALSE( + server_connection->GetStats().address_validated_via_decrypting_packet); + EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); + + // Verify the server received a cached min_rtt from the token and used it as + // the initial rtt. + const CachedNetworkParameters* server_received_network_params = + static_cast( + server_session->GetCryptoStream()) + ->PreviousCachedNetworkParams(); + if (GetQuicReloadableFlag( + quic_add_cached_network_parameters_to_address_token2)) { + ASSERT_NE(server_received_network_params, nullptr); + // QuicSentPacketManager::SetInitialRtt clamps the initial_rtt to between + // [min_initial_rtt, max_initial_rtt]. + const QuicTime::Delta min_initial_rtt = + QuicTime::Delta::FromMicroseconds(kMinInitialRoundTripTimeUs); + const QuicTime::Delta max_initial_rtt = + QuicTime::Delta::FromMicroseconds(kMaxInitialRoundTripTimeUs); + const QuicTime::Delta expected_initial_rtt = + std::max(min_initial_rtt, + std::min(max_initial_rtt, + QuicTime::Delta::FromMilliseconds( + server_received_network_params->min_rtt_ms()))); + EXPECT_EQ( + server_connection->sent_packet_manager().GetRttStats()->initial_rtt(), + expected_initial_rtt); + } else { + EXPECT_EQ(server_received_network_params, nullptr); + } + } else { + ADD_FAILURE() << "Missing server connection"; + } + + server_thread_->Resume(); + + client_->Disconnect(); + + // Regression test for b/206087883. + // Mock server crash. + StopServer(); + + // The handshake fails due to idle timeout. + client_->Connect(); + ASSERT_FALSE(client_->client()->WaitForOneRttKeysAvailable()); + client_->WaitForWriteToFlush(); + client_->WaitForResponse(); + ASSERT_FALSE(client_->client()->connected()); + EXPECT_THAT(client_->connection_error(), IsError(QUIC_NETWORK_IDLE_TIMEOUT)); + + // Server restarts. + server_writer_ = new PacketDroppingTestWriter(); + StartServer(); + + // Client re-connect. + client_->Connect(); + ASSERT_TRUE(client_->client()->WaitForHandshakeConfirmed()); + client_->WaitForWriteToFlush(); + client_->WaitForResponse(); + ASSERT_TRUE(client_->client()->connected()); + client_session = GetClientSession(); + ASSERT_TRUE(client_session); + EXPECT_FALSE(client_session->EarlyDataAccepted()); + EXPECT_FALSE(client_->client()->EarlyDataAccepted()); + server_thread_->Pause(); + server_session = GetServerSession(); + server_connection = GetServerConnection(); + if (!GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + // Address token is reused. + if (server_session != nullptr && server_connection != nullptr) { // Verify address is validated via validating token received in INITIAL // packet. EXPECT_FALSE(server_connection->GetStats() .address_validated_via_decrypting_packet); EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); } else { + ADD_FAILURE() << "Missing server connection"; + } + } else { + // Verify address token is only used once. + if (server_session != nullptr && server_connection != nullptr) { + // Verify address is validated via decrypting packet. EXPECT_TRUE(server_connection->GetStats() .address_validated_via_decrypting_packet); EXPECT_FALSE(server_connection->GetStats().address_validated_via_token); + } else { + ADD_FAILURE() << "Missing server connection"; } + } + server_thread_->Resume(); + + client_->Disconnect(); +} + +TEST_P(EndToEndTest, AddressTokenRefreshedByServer) { + SetQuicReloadableFlag(quic_add_cached_network_parameters_to_address_token2, + true); + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + + QuicCryptoClientConfig* client_crypto_config = + client_->client()->crypto_config(); + QuicServerId server_id = client_->client()->server_id(); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(GetClientSession()->EarlyDataAccepted()); + + client_->Disconnect(); + + QuicClientSessionCache* session_cache = static_cast( + client_crypto_config->mutable_session_cache()); + std::string old_address_token; + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + old_address_token = + QuicClientSessionCachePeer::GetToken(session_cache, server_id); } else { - ADD_FAILURE() << "Missing server connection"; + old_address_token = + client_crypto_config->LookupOrCreate(server_id)->source_address_token(); } + ASSERT_TRUE(!old_address_token.empty()); + + SetQuicReloadableFlag(quic_add_cached_network_parameters_to_address_token2, + false); + + // The 0-RTT handshake should succeed. + client_->Connect(); + EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); + ASSERT_TRUE(client_->client()->connected()); + SendSynchronousFooRequestAndCheckResponse(); + + EXPECT_TRUE(GetClientSession()->EarlyDataAccepted()); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + QuicConnection* server_connection = GetServerConnection(); + ASSERT_TRUE(server_session != nullptr && server_connection != nullptr); + // Verify address is validated via validating token received in INITIAL + // packet. + EXPECT_FALSE( + server_connection->GetStats().address_validated_via_decrypting_packet); + EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); server_thread_->Resume(); client_->Disconnect(); + + std::string new_address_token; + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + new_address_token = + QuicClientSessionCachePeer::GetToken(session_cache, server_id); + } else { + new_address_token = + client_crypto_config->LookupOrCreate(server_id)->source_address_token(); + } + ASSERT_TRUE(!new_address_token.empty()); + ASSERT_NE(new_address_token, old_address_token); +} + +// Verify that client does not reuse a source address token. +TEST_P(EndToEndTest, AddressTokenNotReusedByClient) { + ASSERT_TRUE(Initialize()); + if (!version_.HasIetfQuicFrames()) { + return; + } + + QuicCryptoClientConfig* client_crypto_config = + client_->client()->crypto_config(); + QuicServerId server_id = client_->client()->server_id(); + + SendSynchronousFooRequestAndCheckResponse(); + EXPECT_FALSE(GetClientSession()->EarlyDataAccepted()); + + client_->Disconnect(); + + QuicClientSessionCache* session_cache = static_cast( + client_crypto_config->mutable_session_cache()); + std::string old_address_token; + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + old_address_token = + QuicClientSessionCachePeer::GetToken(session_cache, server_id); + } else { + old_address_token = + client_crypto_config->LookupOrCreate(server_id)->source_address_token(); + } + ASSERT_TRUE(!old_address_token.empty()); + + // Pause the server thread again to blackhole packets from client. + server_thread_->Pause(); + client_->Connect(); + EXPECT_FALSE(client_->client()->WaitForOneRttKeysAvailable()); + EXPECT_FALSE(client_->client()->connected()); + + std::string new_address_token; + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + new_address_token = + QuicClientSessionCachePeer::GetToken(session_cache, server_id); + // Verify address token gets cleared. + ASSERT_TRUE(new_address_token.empty()); + } else { + new_address_token = + client_crypto_config->LookupOrCreate(server_id)->source_address_token(); + ASSERT_FALSE(new_address_token.empty()); + } + server_thread_->Resume(); } TEST_P(EndToEndTest, LargePostZeroRTTFailure) { @@ -1876,11 +2178,6 @@ TEST_P(EndToEndTest, RetransmissionAfterZeroRTTRejectBeforeOneRtt) { ON_CALL(visitor, OnZeroRttRejected(_)).WillByDefault(Invoke([this]() { EXPECT_FALSE(GetClientSession()->IsEncryptionEstablished()); - if (!GetQuicReloadableFlag(quic_donot_write_mid_packet_processing)) { - // Trigger an OnCanWrite() to make sure no unencrypted data will be - // written. - GetClientSession()->OnCanWrite(); - } })); // The 0-RTT handshake should fail. @@ -2477,10 +2774,7 @@ TEST_P( HalfRttResponseBlocksShloRetransmissionWithoutTokenBasedAddressValidation) { // Turn off token based address validation to make the server get constrained // by amplification factor during handshake. - // TODO(fayang): Keep this test while deprecating - // quic_enable_token_based_address_validation. For example, consider always - // rejecting the received address token. - SetQuicReloadableFlag(quic_enable_token_based_address_validation, false); + SetQuicFlag(FLAGS_quic_reject_retry_token_in_initial_packet, true); ASSERT_TRUE(Initialize()); if (!version_.SupportsAntiAmplificationLimit()) { return; @@ -2502,17 +2796,8 @@ TEST_P( // Large response (100KB) for 0-RTT request. std::string large_body(102400, 'a'); AddToCache("/large_response", 200, large_body); - if (GetQuicReloadableFlag(quic_preempt_stream_data_with_handshake_packet)) { - SendSynchronousRequestAndCheckResponse(client_.get(), "/large_response", - large_body); - } else { - // Server consistently gets constrained by amplification factor, hence PTO - // never gets armed. The CHLO retransmission would trigger the - // retransmission of SHLO, however, the ENCRYPTION_HANDSHAKE packet NEVER - // gets retransmitted since half RTT data consumes the remaining space in - // the coalescer. - EXPECT_EQ("", client_->SendSynchronousRequest("/large_response")); - } + SendSynchronousRequestAndCheckResponse(client_.get(), "/large_response", + large_body); } TEST_P(EndToEndTest, MaxStreamsUberTest) { @@ -2620,52 +2905,250 @@ TEST_P(EndToEndTest, ConnectionMigrationClientIPChanged) { server_thread_->Resume(); } -TEST_P(EndToEndTest, ConnectionMigrationClientIPChangedWithNonEmptyClientCID) { +TEST_P(EndToEndTest, IetfConnectionMigrationClientIPChangedMultipleTimes) { + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + SendSynchronousFooRequestAndCheckResponse(); + + // Store the client IP address which was used to send the first request. + QuicIpAddress host0 = + client_->client()->network_helper()->GetLatestClientAddress().host(); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection != nullptr); + + // Migrate socket to a new IP address. + QuicIpAddress host1 = TestLoopback(2); + EXPECT_NE(host0, host1); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + QuicConnectionId server_cid0 = client_connection->connection_id(); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid1 = client_connection->connection_id(); + EXPECT_FALSE(server_cid1.IsEmpty()); + EXPECT_NE(server_cid0, server_cid1); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + + // Send another request and wait for response making sure path response is + // received at server. + SendSynchronousBarRequestAndCheckResponse(); + + // Migrate socket to a new IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + QuicIpAddress host2 = TestLoopback(3); + EXPECT_NE(host0, host2); + EXPECT_NE(host1, host2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host2)); + QuicConnectionId server_cid2 = client_connection->connection_id(); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid0, server_cid2); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request using the new socket and wait for response making sure + // path response is received at server. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(2u, + client_connection->GetStats().num_connectivity_probing_received); + + // Migrate socket back to an old IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid3 = client_connection->connection_id(); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid0, server_cid3); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + const auto* client_packet_creator = + QuicConnectionPeer::GetPacketCreator(client_connection); + EXPECT_TRUE(client_packet_creator->GetClientConnectionId().IsEmpty()); + EXPECT_EQ(server_cid3, client_packet_creator->GetServerConnectionId()); + + // Send another request using the new socket and wait for response making sure + // path response is received at server. + SendSynchronousBarRequestAndCheckResponse(); + // Even this is an old path, server has forgotten about it and thus needs to + // validate the path again. + EXPECT_EQ(3u, + client_connection->GetStats().num_connectivity_probing_received); + + WaitForNewConnectionIds(); + EXPECT_EQ(3u, client_connection->GetStats().num_retire_connection_id_sent); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + // By the time the 2nd request is completed, the PATH_RESPONSE must have been + // received by the server. + EXPECT_FALSE(server_connection->HasPendingPathValidation()); + EXPECT_EQ(3u, server_connection->GetStats().num_validated_peer_migration); + EXPECT_EQ(server_cid3, server_connection->connection_id()); + const auto* server_packet_creator = + QuicConnectionPeer::GetPacketCreator(server_connection); + EXPECT_EQ(server_cid3, server_packet_creator->GetServerConnectionId()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + server_connection) + .IsEmpty()); + EXPECT_EQ(4u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, + ConnectionMigrationWithNonZeroConnectionIDClientIPChangedMultipleTimes) { if (!version_.SupportsClientConnectionIds()) { ASSERT_TRUE(Initialize()); return; } override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; ASSERT_TRUE(Initialize()); - if (!version_.HasIetfQuicFrames() || - !client_->client()->session()->connection()->validate_client_address()) { + if (!GetClientConnection()->connection_migration_use_new_cid()) { return; } SendSynchronousFooRequestAndCheckResponse(); // Store the client IP address which was used to send the first request. - QuicIpAddress old_host = + QuicIpAddress host0 = client_->client()->network_helper()->GetLatestClientAddress().host(); - auto* client_connection = GetClientConnection(); - QuicConnectionId client_cid0 = client_connection->client_connection_id(); - QuicConnectionId server_cid0 = client_connection->connection_id(); - - // Migrate socket to the new IP address. - QuicIpAddress new_host = TestLoopback(2); - EXPECT_NE(old_host, new_host); - ASSERT_TRUE(client_->client()->MigrateSocket(new_host)); + QuicConnection* client_connection = GetClientConnection(); + ASSERT_TRUE(client_connection != nullptr); - // Send a request using the new socket. + // Migrate socket to a new IP address. + QuicIpAddress host1 = TestLoopback(2); + EXPECT_NE(host0, host1); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + QuicConnectionId server_cid0 = client_connection->connection_id(); + QuicConnectionId client_cid0 = client_connection->client_connection_id(); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid1 = client_connection->connection_id(); + QuicConnectionId client_cid1 = client_connection->client_connection_id(); + EXPECT_FALSE(server_cid1.IsEmpty()); + EXPECT_FALSE(client_cid1.IsEmpty()); + EXPECT_NE(server_cid0, server_cid1); + EXPECT_NE(client_cid0, client_cid1); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. SendSynchronousBarRequestAndCheckResponse(); - EXPECT_EQ(1u, client_connection->GetStats().num_connectivity_probing_received); - // Send another request. + // Migrate socket to a new IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, client_connection->GetStats().num_new_connection_id_sent); + QuicIpAddress host2 = TestLoopback(3); + EXPECT_NE(host0, host2); + EXPECT_NE(host1, host2); + EXPECT_TRUE(client_->client()->MigrateSocket(host2)); + QuicConnectionId server_cid2 = client_connection->connection_id(); + QuicConnectionId client_cid2 = client_connection->client_connection_id(); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid0, server_cid2); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_FALSE(client_cid2.IsEmpty()); + EXPECT_NE(client_cid0, client_cid2); + EXPECT_NE(client_cid1, client_cid2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_TRUE(QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. + SendSynchronousBarRequestAndCheckResponse(); + EXPECT_EQ(2u, + client_connection->GetStats().num_connectivity_probing_received); + + // Migrate socket back to an old IP address. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(3u, client_connection->GetStats().num_new_connection_id_sent); + EXPECT_TRUE(client_->client()->MigrateSocket(host1)); + QuicConnectionId server_cid3 = client_connection->connection_id(); + QuicConnectionId client_cid3 = client_connection->client_connection_id(); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid0, server_cid3); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_FALSE(client_cid3.IsEmpty()); + EXPECT_NE(client_cid0, client_cid3); + EXPECT_NE(client_cid1, client_cid3); + EXPECT_NE(client_cid2, client_cid3); + const auto* client_packet_creator = + QuicConnectionPeer::GetPacketCreator(client_connection); + EXPECT_EQ(client_cid3, client_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid3, client_packet_creator->GetServerConnectionId()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send another request to ensure that the server will have time to finish the + // reverse path validation and send address token. SendSynchronousBarRequestAndCheckResponse(); + // Even this is an old path, server has forgotten about it and thus needs to + // validate the path again. + EXPECT_EQ(3u, + client_connection->GetStats().num_connectivity_probing_received); - EXPECT_EQ(client_cid0, client_connection->client_connection_id()); - EXPECT_EQ(server_cid0, client_connection->connection_id()); + WaitForNewConnectionIds(); + EXPECT_EQ(3u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(4u, client_connection->GetStats().num_new_connection_id_sent); + server_thread_->Pause(); // By the time the 2nd request is completed, the PATH_RESPONSE must have been // received by the server. - server_thread_->Pause(); QuicConnection* server_connection = GetServerConnection(); - ASSERT_THAT(server_connection, NotNull()); EXPECT_FALSE(server_connection->HasPendingPathValidation()); - EXPECT_EQ(1u, server_connection->GetStats().num_validated_peer_migration); - EXPECT_EQ(client_cid0, server_connection->client_connection_id()); - EXPECT_EQ(server_cid0, server_connection->connection_id()); + EXPECT_EQ(3u, server_connection->GetStats().num_validated_peer_migration); + EXPECT_EQ(server_cid3, server_connection->connection_id()); + EXPECT_EQ(client_cid3, server_connection->client_connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + server_connection) + .IsEmpty()); + const auto* server_packet_creator = + QuicConnectionPeer::GetPacketCreator(server_connection); + EXPECT_EQ(client_cid3, server_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid3, server_packet_creator->GetServerConnectionId()); + EXPECT_EQ(3u, server_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(4u, server_connection->GetStats().num_new_connection_id_sent); server_thread_->Resume(); } @@ -2693,7 +3176,7 @@ TEST_P(EndToEndTest, ConnectionMigrationNewTokenForNewIp) { EXPECT_EQ(1u, client_connection->GetStats().num_connectivity_probing_received); - // Send another request to ensure that the server will time to finish the + // Send another request to ensure that the server will have time to finish the // reverse path validation and send address token. SendSynchronousBarRequestAndCheckResponse(); @@ -2710,17 +3193,11 @@ TEST_P(EndToEndTest, ConnectionMigrationNewTokenForNewIp) { server_thread_->Pause(); QuicConnection* server_connection = GetServerConnection(); if (server_connection != nullptr) { - if (GetQuicReloadableFlag(quic_enable_token_based_address_validation)) { - // Verify address is validated via validating token received in INITIAL - // packet. - EXPECT_FALSE(server_connection->GetStats() - .address_validated_via_decrypting_packet); - EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); - } else { - EXPECT_TRUE(server_connection->GetStats() - .address_validated_via_decrypting_packet); - EXPECT_FALSE(server_connection->GetStats().address_validated_via_token); - } + // Verify address is validated via validating token received in INITIAL + // packet. + EXPECT_FALSE( + server_connection->GetStats().address_validated_via_decrypting_packet); + EXPECT_TRUE(server_connection->GetStats().address_validated_via_token); } else { ADD_FAILURE() << "Missing server connection"; } @@ -2758,14 +3235,19 @@ class DuplicatePacketWithSpoofedSelfAddressWriter TEST_P(EndToEndTest, ClientAddressSpoofedForSomePeriod) { ASSERT_TRUE(Initialize()); - if (!version_.HasIetfQuicFrames() || - !client_->client()->session()->connection()->validate_client_address()) { + if (!GetClientConnection()->connection_migration_use_new_cid()) { return; } auto writer = new DuplicatePacketWithSpoofedSelfAddressWriter(); client_.reset(CreateQuicClient(writer)); + + // Make sure client has unused peer connection ID before migration. + SendSynchronousFooRequestAndCheckResponse(); + ASSERT_TRUE(QuicConnectionPeer::HasUnusedPeerIssuedConnectionId( + GetClientConnection())); + QuicIpAddress real_host = TestLoopback(1); - client_->MigrateSocket(real_host); + ASSERT_TRUE(client_->MigrateSocket(real_host)); SendSynchronousFooRequestAndCheckResponse(); EXPECT_EQ( 0u, GetClientConnection()->GetStats().num_connectivity_probing_received); @@ -2804,10 +3286,10 @@ TEST_P(EndToEndTest, ClientAddressSpoofedForSomePeriod) { EXPECT_EQ(large_body, client_->response_body()); } -TEST_P(EndToEndTest, AsynchronousConnectionMigrationClientIPChanged) { +TEST_P(EndToEndTest, + AsynchronousConnectionMigrationClientIPChangedMultipleTimes) { ASSERT_TRUE(Initialize()); - if (!version_.HasIetfQuicFrames() || - !client_->client()->session()->connection()->use_path_validator()) { + if (!GetClientConnection()->connection_migration_use_new_cid()) { return; } client_.reset(CreateQuicClient(nullptr)); @@ -2815,36 +3297,96 @@ TEST_P(EndToEndTest, AsynchronousConnectionMigrationClientIPChanged) { SendSynchronousFooRequestAndCheckResponse(); // Store the client IP address which was used to send the first request. - QuicIpAddress old_host = + QuicIpAddress host0 = client_->client()->network_helper()->GetLatestClientAddress().host(); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionId server_cid0 = client_connection->connection_id(); + // Server should have one new connection ID upon handshake completion. + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + + // Migrate socket to new IP address #1. + QuicIpAddress host1 = TestLoopback(2); + EXPECT_NE(host0, host1); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host1)); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host1, client_->client()->session()->self_address().host()); + EXPECT_EQ(1u, + client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId server_cid1 = client_connection->connection_id(); + EXPECT_NE(server_cid0, server_cid1); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); - // Migrate socket to the new IP address. - QuicIpAddress new_host = TestLoopback(2); - EXPECT_NE(old_host, new_host); - ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(new_host)); + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + + // Migrate socket to new IP address #2. + WaitForNewConnectionIds(); + QuicIpAddress host2 = TestLoopback(3); + EXPECT_NE(host0, host1); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host2)); while (client_->client()->HasPendingPathValidation()) { client_->client()->WaitForEvents(); } - EXPECT_EQ(new_host, client_->client()->session()->self_address().host()); - QuicConnection* client_connection = GetClientConnection(); - ASSERT_TRUE(client_connection); - EXPECT_EQ(client_connection->validate_client_address() ? 1u : 0, + EXPECT_EQ(host2, client_->client()->session()->self_address().host()); + EXPECT_EQ(2u, client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId server_cid2 = client_connection->connection_id(); + EXPECT_NE(server_cid0, server_cid2); + EXPECT_NE(server_cid1, server_cid2); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + // Send a request using the new socket. SendSynchronousBarRequestAndCheckResponse(); -} -TEST_P(EndToEndTest, - AsynchronousConnectionMigrationClientIPChangedWithNonEmptyClientCID) { - if (!version_.SupportsClientConnectionIds()) { - ASSERT_TRUE(Initialize()); - return; - } + // Migrate socket back to IP address #1. + WaitForNewConnectionIds(); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host1)); + + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host1, client_->client()->session()->self_address().host()); + EXPECT_EQ(3u, + client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId server_cid3 = client_connection->connection_id(); + EXPECT_NE(server_cid0, server_cid3); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + + // Send a request using the new socket. + SendSynchronousBarRequestAndCheckResponse(); + server_thread_->Pause(); + const QuicConnection* server_connection = GetServerConnection(); + EXPECT_EQ(server_connection->connection_id(), server_cid3); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + server_connection) + .IsEmpty()); + server_thread_->Resume(); + + // There should be 1 new connection ID issued by the server. + WaitForNewConnectionIds(); +} + +TEST_P(EndToEndTest, + AsynchronousConnectionMigrationClientIPChangedWithNonEmptyClientCID) { + if (!version_.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; ASSERT_TRUE(Initialize()); - if (!version_.HasIetfQuicFrames() || - !client_->client()->session()->connection()->validate_client_address()) { + if (!GetClientConnection()->connection_migration_use_new_cid()) { return; } client_.reset(CreateQuicClient(nullptr)); @@ -2869,13 +3411,23 @@ TEST_P(EndToEndTest, EXPECT_EQ(new_host, client_->client()->session()->self_address().host()); EXPECT_EQ(1u, client_connection->GetStats().num_connectivity_probing_received); + QuicConnectionId client_cid1 = client_connection->client_connection_id(); + QuicConnectionId server_cid1 = client_connection->connection_id(); + const auto* client_packet_creator = + QuicConnectionPeer::GetPacketCreator(client_connection); + EXPECT_EQ(client_cid1, client_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid1, client_packet_creator->GetServerConnectionId()); // Send a request using the new socket. SendSynchronousBarRequestAndCheckResponse(); server_thread_->Pause(); QuicConnection* server_connection = GetServerConnection(); - EXPECT_EQ(client_cid0, server_connection->client_connection_id()); - EXPECT_EQ(server_cid0, server_connection->connection_id()); + EXPECT_EQ(client_cid1, server_connection->client_connection_id()); + EXPECT_EQ(server_cid1, server_connection->connection_id()); + const auto* server_packet_creator = + QuicConnectionPeer::GetPacketCreator(server_connection); + EXPECT_EQ(client_cid1, server_packet_creator->GetClientConnectionId()); + EXPECT_EQ(server_cid1, server_packet_creator->GetServerConnectionId()); server_thread_->Resume(); } @@ -3582,6 +4134,112 @@ TEST_P(EndToEndTest, ServerSendVersionNegotiationWithDifferentConnectionId) { client_connection->set_debug_visitor(nullptr); } +// DowngradePacketWriter is a client writer which will intercept all the client +// writes for |target_version| and reply to them with version negotiation +// packets to attempt a version downgrade attack. Once the client has downgraded +// to a different version, the writer stops intercepting. |server_thread| must +// start off paused, and will be resumed once interception is done. +class DowngradePacketWriter : public PacketDroppingTestWriter { + public: + explicit DowngradePacketWriter( + const ParsedQuicVersion& target_version, + const ParsedQuicVersionVector& supported_versions, QuicTestClient* client, + QuicPacketWriter* server_writer, ServerThread* server_thread) + : target_version_(target_version), + supported_versions_(supported_versions), + client_(client), + server_writer_(server_writer), + server_thread_(server_thread) {} + ~DowngradePacketWriter() override {} + + WriteResult WritePacket(const char* buffer, size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + quic::PerPacketOptions* options) override { + if (!intercept_enabled_) { + return PacketDroppingTestWriter::WritePacket( + buffer, buf_len, self_address, peer_address, options); + } + PacketHeaderFormat format; + QuicLongHeaderType long_packet_type; + bool version_present, has_length_prefix; + QuicVersionLabel version_label; + ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); + QuicConnectionId destination_connection_id, source_connection_id; + absl::optional retry_token; + std::string detailed_error; + if (QuicFramer::ParsePublicHeaderDispatcher( + QuicEncryptedPacket(buffer, buf_len), + kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_present, &has_length_prefix, &version_label, + &parsed_version, &destination_connection_id, &source_connection_id, + &retry_token, &detailed_error) != QUIC_NO_ERROR) { + ADD_FAILURE() << "Failed to parse our own packet: " << detailed_error; + return WriteResult(WRITE_STATUS_ERROR, 0); + } + if (!version_present || parsed_version != target_version_) { + // Client is sending with another version, the attack has succeeded so we + // can stop intercepting. + intercept_enabled_ = false; + server_thread_->Resume(); + // Pass the client-sent packet through. + return WritePacket(buffer, buf_len, self_address, peer_address, options); + } + // Send a version negotiation packet. + std::unique_ptr packet( + QuicFramer::BuildVersionNegotiationPacket( + destination_connection_id, source_connection_id, + parsed_version.HasIetfInvariantHeader(), has_length_prefix, + supported_versions_)); + server_writer_->WritePacket( + packet->data(), packet->length(), peer_address.host(), + client_->client()->network_helper()->GetLatestClientAddress(), nullptr); + // Drop the client-sent packet but pretend it was sent. + return WriteResult(WRITE_STATUS_OK, buf_len); + } + + private: + bool intercept_enabled_ = true; + ParsedQuicVersion target_version_; + ParsedQuicVersionVector supported_versions_; + QuicTestClient* client_; // Unowned. + QuicPacketWriter* server_writer_; // Unowned. + ServerThread* server_thread_; // Unowned. +}; + +TEST_P(EndToEndTest, VersionNegotiationDowngradeAttackIsDetected) { + ParsedQuicVersion target_version = server_supported_versions_.back(); + if (!version_.UsesTls() || target_version == version_) { + ASSERT_TRUE(Initialize()); + return; + } + SetQuicReloadableFlag(quic_version_information, true); + connect_to_server_on_initialize_ = false; + client_supported_versions_.insert(client_supported_versions_.begin(), + target_version); + ParsedQuicVersionVector downgrade_versions{version_}; + ASSERT_TRUE(Initialize()); + ASSERT_TRUE(server_thread_); + // Pause the server thread to allow our DowngradePacketWriter to write version + // negotiation packets in a thread-safe manner. It will be resumed by the + // DowngradePacketWriter. + server_thread_->Pause(); + client_.reset(new QuicTestClient(server_address_, server_hostname_, + client_config_, client_supported_versions_, + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique())); + delete client_writer_; + client_writer_ = new DowngradePacketWriter(target_version, downgrade_versions, + client_.get(), server_writer_, + server_thread_.get()); + client_->UseWriter(client_writer_); + // Have the client attempt to send a request. + client_->Connect(); + EXPECT_TRUE(client_->SendSynchronousRequest("/foo").empty()); + // Make sure the downgrade is detected and the handshake fails. + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_FAILED)); +} + // A bad header shouldn't tear down the connection, because the receiver can't // tell the connection ID. TEST_P(EndToEndTest, BadPacketHeaderTruncated) { @@ -3620,7 +4278,7 @@ TEST_P(EndToEndTest, BadPacketHeaderFlags) { SendSynchronousFooRequestAndCheckResponse(); // Packet with invalid public flags. - char packet[] = { + uint8_t packet[] = { // invalid public flags 0xFF, // connection_id @@ -3643,7 +4301,7 @@ TEST_P(EndToEndTest, BadPacketHeaderFlags) { 0x00, }; client_writer_->WritePacket( - &packet[0], sizeof(packet), + reinterpret_cast(packet), sizeof(packet), client_->client()->network_helper()->GetLatestClientAddress().host(), server_address_, nullptr); @@ -3973,312 +4631,6 @@ TEST_P(EndToEndTest, Trailers) { EXPECT_EQ(trailers, client_->response_trailers()); } -class EndToEndTestServerPush : public EndToEndTest { - protected: - const size_t kNumMaxStreams = 10; - - EndToEndTestServerPush() : EndToEndTest() { - client_config_.SetMaxBidirectionalStreamsToSend(kNumMaxStreams); - server_config_.SetMaxBidirectionalStreamsToSend(kNumMaxStreams); - client_config_.SetMaxUnidirectionalStreamsToSend(kNumMaxStreams); - server_config_.SetMaxUnidirectionalStreamsToSend(kNumMaxStreams); - } - - // Add a request with its response and |num_resources| push resources into - // cache. - // If |resource_size| == 0, response body of push resources use default string - // concatenating with resource url. Otherwise, generate a string of - // |resource_size| as body. - void AddRequestAndResponseWithServerPush(std::string host, - std::string path, - std::string response_body, - std::string* push_urls, - const size_t num_resources, - const size_t resource_size) { - bool use_large_response = resource_size != 0; - std::string large_resource; - if (use_large_response) { - // Generate a response common body larger than flow control window for - // push response. - large_resource = std::string(resource_size, 'a'); - } - std::list push_resources; - for (size_t i = 0; i < num_resources; ++i) { - std::string url = push_urls[i]; - QuicUrl resource_url(url); - std::string body = - use_large_response - ? large_resource - : absl::StrCat("This is server push response body for ", url); - SpdyHeaderBlock response_headers; - response_headers[":status"] = "200"; - response_headers["content-length"] = absl::StrCat(body.size()); - push_resources.push_back(QuicBackendResponse::ServerPushInfo( - resource_url, std::move(response_headers), kV3LowestPriority, body)); - } - - memory_cache_backend_.AddSimpleResponseWithServerPushResources( - host, path, 200, response_body, push_resources); - } -}; - -// Run all server push end to end tests with all supported versions. -INSTANTIATE_TEST_SUITE_P(EndToEndTestsServerPush, - EndToEndTestServerPush, - ::testing::ValuesIn(GetTestParams()), - ::testing::PrintToStringParamName()); - -TEST_P(EndToEndTestServerPush, ServerPush) { - ASSERT_TRUE(Initialize()); - EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); - - // Set reordering to ensure that body arriving before PUSH_PROMISE is ok. - SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); - SetReorderPercentage(30); - - // Add a response with headers, body, and push resources. - const std::string kBody = "body content"; - size_t kNumResources = 4; - std::string push_urls[] = {"https://example.com/font.woff", - "https://example.com/script.js", - "https://fonts.example.com/font.woff", - "https://example.com/logo-hires.jpg"}; - AddRequestAndResponseWithServerPush("example.com", "/push_example", kBody, - push_urls, kNumResources, 0); - - client_->client()->set_response_listener( - std::unique_ptr( - new TestResponseListener)); - - QUIC_DVLOG(1) << "send request for /push_example"; - EXPECT_EQ(kBody, client_->SendSynchronousRequest( - "https://example.com/push_example")); - QuicStreamSequencer* sequencer = nullptr; - if (!version_.UsesHttp3()) { - QuicSpdyClientSession* client_session = GetClientSession(); - ASSERT_TRUE(client_session); - QuicHeadersStream* headers_stream = - QuicSpdySessionPeer::GetHeadersStream(client_session); - ASSERT_TRUE(headers_stream); - sequencer = QuicStreamPeer::sequencer(headers_stream); - ASSERT_TRUE(sequencer); - // Headers stream's sequencer buffer shouldn't be released because server - // push hasn't finished yet. - EXPECT_TRUE( - QuicStreamSequencerPeer::IsUnderlyingBufferAllocated(sequencer)); - } - - for (const std::string& url : push_urls) { - QUIC_DVLOG(1) << "send request for pushed stream on url " << url; - std::string expected_body = - absl::StrCat("This is server push response body for ", url); - std::string response_body = client_->SendSynchronousRequest(url); - QUIC_DVLOG(1) << "response body " << response_body; - EXPECT_EQ(expected_body, response_body); - } - if (!version_.UsesHttp3()) { - ASSERT_TRUE(sequencer); - EXPECT_FALSE( - QuicStreamSequencerPeer::IsUnderlyingBufferAllocated(sequencer)); - } -} - -TEST_P(EndToEndTestServerPush, ServerPushUnderLimit) { - // Tests that sending a request which has 4 push resources will trigger server - // to push those 4 resources and client can handle pushed resources and match - // them with requests later. - ASSERT_TRUE(Initialize()); - - if (version_.UsesHttp3()) { - return; - } - - EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); - // Set reordering to ensure that body arriving before PUSH_PROMISE is ok. - SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); - SetReorderPercentage(30); - - // Add a response with headers, body, and push resources. - const std::string kBody = "body content"; - size_t const kNumResources = 4; - std::string push_urls[] = { - "https://example.com/font.woff", - "https://example.com/script.js", - "https://fonts.example.com/font.woff", - "https://example.com/logo-hires.jpg", - }; - AddRequestAndResponseWithServerPush("example.com", "/push_example", kBody, - push_urls, kNumResources, 0); - client_->client()->set_response_listener( - std::unique_ptr( - new TestResponseListener)); - - // Send the first request: this will trigger the server to send all the push - // resources associated with this request, and these will be cached by the - // client. - EXPECT_EQ(kBody, client_->SendSynchronousRequest( - "https://example.com/push_example")); - - for (const std::string& url : push_urls) { - // Sending subsequent requesets will not actually send anything on the wire, - // as the responses are already in the client's cache. - QUIC_DVLOG(1) << "send request for pushed stream on url " << url; - std::string expected_body = - absl::StrCat("This is server push response body for ", url); - std::string response_body = client_->SendSynchronousRequest(url); - QUIC_DVLOG(1) << "response body " << response_body; - EXPECT_EQ(expected_body, response_body); - } - // Expect only original request has been sent and push responses have been - // received as normal response. - EXPECT_EQ(1u, client_->num_requests()); - EXPECT_EQ(1u + kNumResources, client_->num_responses()); -} - -TEST_P(EndToEndTestServerPush, ServerPushOverLimitNonBlocking) { - if (version_.UsesHttp3()) { - ASSERT_TRUE(Initialize()); - return; - } - // Tests that when streams are not blocked by flow control or congestion - // control, pushing even more resources than max number of open outgoing - // streams should still work because all response streams get closed - // immediately after pushing resources. - ASSERT_TRUE(Initialize()); - EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); - - // Set reordering to ensure that body arriving before PUSH_PROMISE is ok. - SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); - SetReorderPercentage(30); - - // Add a response with headers, body, and push resources. - const std::string kBody = "body content"; - - // One more resource than max number of outgoing stream of this session. - const size_t kNumResources = 1 + kNumMaxStreams; // 11. - std::string push_urls[11]; - for (size_t i = 0; i < kNumResources; ++i) { - push_urls[i] = absl::StrCat("https://example.com/push_resources", i); - } - AddRequestAndResponseWithServerPush("example.com", "/push_example", kBody, - push_urls, kNumResources, 0); - client_->client()->set_response_listener( - std::unique_ptr( - new TestResponseListener)); - - // Send the first request: this will trigger the server to send all the push - // resources associated with this request, and these will be cached by the - // client. - EXPECT_EQ(kBody, client_->SendSynchronousRequest( - "https://example.com/push_example")); - - for (const std::string& url : push_urls) { - // Sending subsequent requesets will not actually send anything on the wire, - // as the responses are already in the client's cache. - EXPECT_EQ(absl::StrCat("This is server push response body for ", url), - client_->SendSynchronousRequest(url)); - } - - // Only 1 request should have been sent. - EXPECT_EQ(1u, client_->num_requests()); - // The responses to the original request and all the promised resources - // should have been received. - EXPECT_EQ(12u, client_->num_responses()); -} - -TEST_P(EndToEndTestServerPush, ServerPushOverLimitWithBlocking) { - if (version_.UsesHttp3()) { - ASSERT_TRUE(Initialize()); - return; - } - - // Tests that when server tries to send more large resources(large enough to - // be blocked by flow control window or congestion control window) than max - // open outgoing streams , server can open upto max number of outgoing - // streams for them, and the rest will be queued up. - - // Reset flow control windows. - size_t kFlowControlWnd = 20 * 1024; // 20KB. - // Response body is larger than 1 flow controlblock window. - size_t kBodySize = kFlowControlWnd * 2; - set_client_initial_stream_flow_control_receive_window(kFlowControlWnd); - // Make sure conntection level flow control window is large enough not to - // block data being sent out though they will be blocked by stream level one. - set_client_initial_session_flow_control_receive_window( - kBodySize * kNumMaxStreams + 1024); - - ASSERT_TRUE(Initialize()); - EXPECT_TRUE(client_->client()->WaitForOneRttKeysAvailable()); - - // Set reordering to ensure that body arriving before PUSH_PROMISE is ok. - SetPacketSendDelay(QuicTime::Delta::FromMilliseconds(2)); - SetReorderPercentage(30); - - // Add a response with headers, body, and push resources. - const std::string kBody = "body content"; - - const size_t kNumResources = kNumMaxStreams + 1; - std::string push_urls[11]; - for (size_t i = 0; i < kNumResources; ++i) { - push_urls[i] = absl::StrCat("http://example.com/push_resources", i); - } - AddRequestAndResponseWithServerPush("example.com", "/push_example", kBody, - push_urls, kNumResources, kBodySize); - - client_->client()->set_response_listener( - std::unique_ptr( - new TestResponseListener)); - - client_->SendRequest("https://example.com/push_example"); - - // Pause after the first response arrives. - while (!client_->response_complete()) { - // Because of priority, the first response arrived should be to original - // request. - client_->WaitForResponse(); - ASSERT_TRUE(client_->connected()); - } - - // Check server session to see if it has max number of outgoing streams opened - // though more resources need to be pushed. - if (!version_.HasIetfQuicFrames()) { - server_thread_->Pause(); - QuicSession* server_session = GetServerSession(); - if (server_session != nullptr) { - EXPECT_EQ(kNumMaxStreams, - QuicSessionPeer::GetStreamIdManager(server_session) - ->num_open_outgoing_streams()); - } else { - ADD_FAILURE() << "Missing server session"; - } - server_thread_->Resume(); - } - - EXPECT_EQ(1u, client_->num_requests()); - EXPECT_EQ(1u, client_->num_responses()); - EXPECT_EQ(kBody, client_->response_body()); - - // "Send" request for a promised resources will not really send out it because - // its response is being pushed(but blocked). And the following ack and - // flow control behavior of SendSynchronousRequests() - // will unblock the stream to finish receiving response. - client_->SendSynchronousRequest(push_urls[0]); - EXPECT_EQ(1u, client_->num_requests()); - EXPECT_EQ(2u, client_->num_responses()); - - // Do same thing for the rest 10 resources. - for (size_t i = 1; i < kNumResources; ++i) { - client_->SendSynchronousRequest(push_urls[i]); - } - - // Because of server push, client gets all pushed resources without actually - // sending requests for them. - EXPECT_EQ(1u, client_->num_requests()); - // Including response to original request, 12 responses in total were - // received. - EXPECT_EQ(12u, client_->num_responses()); -} - // TODO(fayang): this test seems to cause net_unittests timeouts :| TEST_P(EndToEndTest, DISABLED_TestHugePostWithPacketLoss) { // This test tests a huge post with introduced packet loss from client to @@ -4533,11 +4885,7 @@ TEST_P(EndToEndTest, client_.reset(CreateQuicClient(client_writer_)); EXPECT_EQ("", client_->SendSynchronousRequest("/foo")); - if (GetQuicReloadableFlag(quic_fix_dispatcher_sent_error_code)) { - EXPECT_THAT(client_->connection_error(), IsError(QUIC_PACKET_WRITE_ERROR)); - } else { - EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_FAILED)); - } + EXPECT_THAT(client_->connection_error(), IsError(QUIC_HANDSHAKE_FAILED)); } // Regression test for b/116200989. @@ -4787,20 +5135,15 @@ TEST_P(EndToEndTest, SendMessages) { ASSERT_LT(0, client_session->GetCurrentLargestMessagePayload()); std::string message_string(kMaxOutgoingPacketSize, 'a'); - absl::string_view message_buffer(message_string); QuicRandom* random = QuicConnectionPeer::GetHelper(client_connection)->GetRandomGenerator(); - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); { QuicConnection::ScopedPacketFlusher flusher(client_session->connection()); // Verify the largest message gets successfully sent. EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, 1), - client_session->SendMessage(MakeSpan( - client_connection->helper()->GetStreamSendBufferAllocator(), - absl::string_view( - message_buffer.data(), - client_session->GetCurrentLargestMessagePayload()), - &storage))); + client_session->SendMessage(MemSliceFromString(absl::string_view( + message_string.data(), + client_session->GetCurrentLargestMessagePayload())))); // Send more messages with size (0, largest_payload] until connection is // write blocked. const int kTestMaxNumberOfMessages = 100; @@ -4809,9 +5152,8 @@ TEST_P(EndToEndTest, SendMessages) { random->RandUint64() % client_session->GetGuaranteedLargestMessagePayload() + 1; - MessageResult result = client_session->SendMessage(MakeSpan( - client_connection->helper()->GetStreamSendBufferAllocator(), - absl::string_view(message_buffer.data(), message_length), &storage)); + MessageResult result = client_session->SendMessage(MemSliceFromString( + absl::string_view(message_string.data(), message_length))); if (result.status == MESSAGE_STATUS_BLOCKED) { // Connection is write blocked. break; @@ -4823,12 +5165,9 @@ TEST_P(EndToEndTest, SendMessages) { client_->WaitForDelayedAcks(); EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, client_session - ->SendMessage(MakeSpan( - client_connection->helper()->GetStreamSendBufferAllocator(), - absl::string_view( - message_buffer.data(), - client_session->GetCurrentLargestMessagePayload() + 1), - &storage)) + ->SendMessage(MemSliceFromString(absl::string_view( + message_string.data(), + client_session->GetCurrentLargestMessagePayload() + 1))) .status); EXPECT_THAT(client_->connection_error(), IsQuicNoError()); } @@ -5100,6 +5439,207 @@ TEST_P(EndToEndPacketReorderingTest, PathValidationFailure) { server_thread_->Resume(); } +TEST_P(EndToEndPacketReorderingTest, MigrateAgainAfterPathValidationFailure) { + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + + client_.reset(CreateQuicClient(nullptr)); + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress addr1 = client_->client()->session()->self_address(); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionId server_cid1 = client_connection->connection_id(); + + // Migrate socket to the new IP address. + QuicIpAddress host2 = TestLoopback(2); + EXPECT_NE(addr1.host(), host2); + + // Drop PATH_RESPONSE packets to timeout the path validation. + server_writer_->set_fake_packet_loss_percentage(100); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host2)); + + QuicConnectionId server_cid2 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid2, server_cid1); + // Wait until path validation fails at the client. + while (client_->client()->HasPendingPathValidation()) { + EXPECT_EQ(server_cid2, + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection)); + client_->client()->WaitForEvents(); + } + EXPECT_EQ(addr1, client_->client()->session()->self_address()); + EXPECT_EQ(server_cid1, GetClientConnection()->connection_id()); + + server_writer_->set_fake_packet_loss_percentage(0); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(0u, client_connection->GetStats().num_new_connection_id_sent); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + // Server has received 3 path challenges. + EXPECT_EQ(3u, + server_connection->GetStats().num_connectivity_probing_received); + EXPECT_EQ(server_cid1, server_connection->connection_id()); + EXPECT_EQ(0u, server_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); + + // Migrate socket to a new IP address again. + QuicIpAddress host3 = TestLoopback(3); + EXPECT_NE(addr1.host(), host3); + EXPECT_NE(host2, host3); + + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(0u, client_connection->GetStats().num_new_connection_id_sent); + + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host3)); + QuicConnectionId server_cid3 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host3, client_->client()->session()->self_address().host()); + EXPECT_EQ(server_cid3, GetClientConnection()->connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + // Server should send a new connection ID to client. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(0u, client_connection->GetStats().num_new_connection_id_sent); +} + +TEST_P(EndToEndPacketReorderingTest, + MigrateAgainAfterPathValidationFailureWithNonZeroClientConnectionId) { + if (!version_.SupportsClientConnectionIds()) { + ASSERT_TRUE(Initialize()); + return; + } + override_client_connection_id_length_ = kQuicDefaultConnectionIdLength; + ASSERT_TRUE(Initialize()); + if (!GetClientConnection()->connection_migration_use_new_cid()) { + return; + } + + client_.reset(CreateQuicClient(nullptr)); + // Finish one request to make sure handshake established. + EXPECT_EQ(kFooResponseBody, client_->SendSynchronousRequest("/foo")); + + // Wait for the connection to become idle, to make sure the packet gets + // delayed is the connectivity probing packet. + client_->WaitForDelayedAcks(); + + QuicSocketAddress addr1 = client_->client()->session()->self_address(); + QuicConnection* client_connection = GetClientConnection(); + QuicConnectionId server_cid1 = client_connection->connection_id(); + QuicConnectionId client_cid1 = client_connection->client_connection_id(); + + // Migrate socket to the new IP address. + QuicIpAddress host2 = TestLoopback(2); + EXPECT_NE(addr1.host(), host2); + + // Drop PATH_RESPONSE packets to timeout the path validation. + server_writer_->set_fake_packet_loss_percentage(100); + ASSERT_TRUE( + QuicConnectionPeer::HasUnusedPeerIssuedConnectionId(client_connection)); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host2)); + QuicConnectionId server_cid2 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid2.IsEmpty()); + EXPECT_NE(server_cid2, server_cid1); + QuicConnectionId client_cid2 = + QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(client_cid2.IsEmpty()); + EXPECT_NE(client_cid2, client_cid1); + while (client_->client()->HasPendingPathValidation()) { + EXPECT_EQ(server_cid2, + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection)); + client_->client()->WaitForEvents(); + } + EXPECT_EQ(addr1, client_->client()->session()->self_address()); + EXPECT_EQ(server_cid1, GetClientConnection()->connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + server_writer_->set_fake_packet_loss_percentage(0); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + WaitForNewConnectionIds(); + EXPECT_EQ(1u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, client_connection->GetStats().num_new_connection_id_sent); + + server_thread_->Pause(); + QuicConnection* server_connection = GetServerConnection(); + if (server_connection != nullptr) { + EXPECT_EQ(3u, + server_connection->GetStats().num_connectivity_probing_received); + EXPECT_EQ(server_cid1, server_connection->connection_id()); + } else { + ADD_FAILURE() << "Missing server connection"; + } + EXPECT_EQ(1u, server_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(2u, server_connection->GetStats().num_new_connection_id_sent); + server_thread_->Resume(); + + // Migrate socket to a new IP address again. + QuicIpAddress host3 = TestLoopback(3); + EXPECT_NE(addr1.host(), host3); + EXPECT_NE(host2, host3); + ASSERT_TRUE(client_->client()->ValidateAndMigrateSocket(host3)); + + QuicConnectionId server_cid3 = + QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection); + EXPECT_FALSE(server_cid3.IsEmpty()); + EXPECT_NE(server_cid1, server_cid3); + EXPECT_NE(server_cid2, server_cid3); + QuicConnectionId client_cid3 = + QuicConnectionPeer::GetClientConnectionIdOnAlternativePath( + client_connection); + EXPECT_NE(client_cid1, client_cid3); + EXPECT_NE(client_cid2, client_cid3); + while (client_->client()->HasPendingPathValidation()) { + client_->client()->WaitForEvents(); + } + EXPECT_EQ(host3, client_->client()->session()->self_address().host()); + EXPECT_EQ(server_cid3, GetClientConnection()->connection_id()); + EXPECT_TRUE(QuicConnectionPeer::GetServerConnectionIdOnAlternativePath( + client_connection) + .IsEmpty()); + EXPECT_EQ(kBarResponseBody, client_->SendSynchronousRequest("/bar")); + + // Server should send new server connection ID to client and retires old + // client connection ID. + WaitForNewConnectionIds(); + EXPECT_EQ(2u, client_connection->GetStats().num_retire_connection_id_sent); + EXPECT_EQ(3u, client_connection->GetStats().num_new_connection_id_sent); +} + TEST_P(EndToEndPacketReorderingTest, Buffer0RttRequest) { ASSERT_TRUE(Initialize()); // Finish one request to make sure handshake established. @@ -5373,8 +5913,10 @@ TEST_P(EndToEndTest, CustomTransportParameters) { QuicConfig* server_config = nullptr; if (server_session != nullptr) { server_config = server_session->config(); - EXPECT_EQ(server_session->user_agent_id().value_or("MissingUserAgent"), - kTestUserAgentId); + if (!GetQuicReloadableFlag(quic_ignore_user_agent_transport_parameter)) { + EXPECT_EQ(server_session->user_agent_id().value_or("MissingUserAgent"), + kTestUserAgentId); + } } else { ADD_FAILURE() << "Missing server session"; } @@ -5477,6 +6019,70 @@ TEST_P(EndToEndTest, LegacyVersionEncapsulationWithLoss) { 0u); } +// Testing packet writer that makes a copy of the first sent packets before +// sending them. Useful for tests that need access to sent packets. +class CopyingPacketWriter : public PacketDroppingTestWriter { + public: + explicit CopyingPacketWriter(int num_packets_to_copy) + : num_packets_to_copy_(num_packets_to_copy) {} + WriteResult WritePacket(const char* buffer, + size_t buf_len, + const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, + PerPacketOptions* options) override { + if (num_packets_to_copy_ > 0) { + num_packets_to_copy_--; + packets_.push_back( + QuicEncryptedPacket(buffer, buf_len, /*owns_buffer=*/false).Clone()); + } + return PacketDroppingTestWriter::WritePacket(buffer, buf_len, self_address, + peer_address, options); + } + + std::vector>& packets() { + return packets_; + } + + private: + int num_packets_to_copy_; + std::vector> packets_; +}; + +TEST_P(EndToEndTest, ChaosProtectionDisabled) { + if (!version_.UsesCryptoFrames()) { + ASSERT_TRUE(Initialize()); + return; + } + // Replace the client's writer with one that'll save the first packet. + auto copying_writer = new CopyingPacketWriter(1); + delete client_writer_; + client_writer_ = copying_writer; + // Disable chaos protection and perform an HTTP request. + client_config_.SetClientConnectionOptions(QuicTagVector{kNCHP}); + ASSERT_TRUE(Initialize()); + SendSynchronousFooRequestAndCheckResponse(); + // Parse the saved packet to make sure it's valid. + SimpleQuicFramer validation_framer({version_}); + validation_framer.framer()->SetInitialObfuscators( + GetClientConnection()->connection_id()); + ASSERT_GT(copying_writer->packets().size(), 0u); + EXPECT_TRUE(validation_framer.ProcessPacket(*copying_writer->packets()[0])); + // TODO(dschinazi) figure out a way to use a MockRandom in this test so we + // can inspect the contents of this packet. +} + +TEST_P(EndToEndTest, DisablePermuteTlsExtensions) { + if (!version_.UsesTls()) { + ASSERT_TRUE(Initialize()); + return; + } + // Disable TLS extension permutation and perform an HTTP request. + client_config_.SetClientConnectionOptions(QuicTagVector{kNBPE}); + ASSERT_TRUE(Initialize()); + EXPECT_FALSE(GetClientSession()->permutes_tls_extensions()); + SendSynchronousFooRequestAndCheckResponse(); +} + TEST_P(EndToEndTest, KeyUpdateInitiatedByClient) { if (!version_.UsesTls()) { // Key Update is only supported in TLS handshake. @@ -5666,7 +6272,7 @@ TEST_P(EndToEndTest, KeyUpdateInitiatedByBoth) { } TEST_P(EndToEndTest, KeyUpdateInitiatedByConfidentialityLimit) { - SetQuicFlag(FLAGS_quic_key_update_confidentiality_limit, 4U); + SetQuicFlag(FLAGS_quic_key_update_confidentiality_limit, 16U); if (!version_.UsesTls()) { // Key Update is only supported in TLS handshake. @@ -5692,9 +6298,11 @@ TEST_P(EndToEndTest, KeyUpdateInitiatedByConfidentialityLimit) { }, QuicTime::Delta::FromSeconds(5)); - SendSynchronousFooRequestAndCheckResponse(); - SendSynchronousFooRequestAndCheckResponse(); - SendSynchronousFooRequestAndCheckResponse(); + for (uint64_t i = 0; + i < GetQuicFlag(FLAGS_quic_key_update_confidentiality_limit); ++i) { + SendSynchronousFooRequestAndCheckResponse(); + } + // Don't know exactly how many packets will be sent in each request/response, // so just test that at least one key update occurred. EXPECT_LE(1u, client_connection->GetStats().key_update_count); @@ -5804,7 +6412,7 @@ TEST_P(EndToEndTest, TlsResumptionDisabledOnTheFly) { client_->Disconnect(); if (early_data_reason != ssl_early_data_session_not_resumed) { - EXPECT_EQ(early_data_reason, ssl_early_data_no_session_offered); + EXPECT_EQ(early_data_reason, ssl_early_data_unsupported_for_session); return; } } @@ -5822,12 +6430,37 @@ TEST_P(EndToEndTest, WebTransportSessionSetup) { WebTransportHttp3* web_transport = CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_NE(web_transport, nullptr); + + server_thread_->Pause(); + QuicSpdySession* server_session = GetServerSession(); + EXPECT_TRUE(server_session->GetWebTransportSession(web_transport->id()) != + nullptr); + server_thread_->Resume(); +} + +TEST_P(EndToEndTest, WebTransportSessionSetupWithEchoWithSuffix) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + // "/echoFoo" should be accepted as "echo" with "set-header" query. + WebTransportHttp3* web_transport = CreateWebTransportSession( + "/echoFoo?set-header=bar:baz", /*wait_for_server_response=*/true); + ASSERT_NE(web_transport, nullptr); server_thread_->Pause(); QuicSpdySession* server_session = GetServerSession(); EXPECT_TRUE(server_session->GetWebTransportSession(web_transport->id()) != nullptr); server_thread_->Resume(); + const spdy::SpdyHeaderBlock* response_headers = client_->response_headers(); + auto it = response_headers->find("bar"); + EXPECT_NE(it, response_headers->end()); + EXPECT_EQ(it->second, "baz"); } TEST_P(EndToEndTest, WebTransportSessionWithLoss) { @@ -5843,6 +6476,7 @@ TEST_P(EndToEndTest, WebTransportSessionWithLoss) { WebTransportHttp3* web_transport = CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_NE(web_transport, nullptr); server_thread_->Pause(); QuicSpdySession* server_session = GetServerSession(); @@ -5867,6 +6501,13 @@ TEST_P(EndToEndTest, WebTransportSessionUnidirectionalStream) { WebTransportStream* outgoing_stream = session->OpenOutgoingUnidirectionalStream(); ASSERT_TRUE(outgoing_stream != nullptr); + + auto stream_visitor = std::make_unique>(); + bool data_acknowledged = false; + EXPECT_CALL(*stream_visitor, OnWriteSideInDataRecvdState()) + .WillOnce(Assign(&data_acknowledged, true)); + outgoing_stream->SetVisitor(std::move(stream_visitor)); + EXPECT_TRUE(outgoing_stream->Write("test")); EXPECT_TRUE(outgoing_stream->SendFin()); @@ -5882,6 +6523,10 @@ TEST_P(EndToEndTest, WebTransportSessionUnidirectionalStream) { WebTransportStream::ReadResult result = received_stream->Read(&received_data); EXPECT_EQ(received_data, "test"); EXPECT_TRUE(result.fin); + + client_->WaitUntil(2000, + [&data_acknowledged]() { return data_acknowledged; }); + EXPECT_TRUE(data_acknowledged); } TEST_P(EndToEndTest, WebTransportSessionUnidirectionalStreamSentEarly) { @@ -5932,11 +6577,24 @@ TEST_P(EndToEndTest, WebTransportSessionBidirectionalStream) { WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); ASSERT_TRUE(stream != nullptr); + + auto stream_visitor_owned = std::make_unique>(); + MockStreamVisitor* stream_visitor = stream_visitor_owned.get(); + bool data_acknowledged = false; + EXPECT_CALL(*stream_visitor, OnWriteSideInDataRecvdState()) + .WillOnce(Assign(&data_acknowledged, true)); + stream->SetVisitor(std::move(stream_visitor_owned)); + EXPECT_TRUE(stream->Write("test")); EXPECT_TRUE(stream->SendFin()); - std::string received_data = ReadDataFromWebTransportStreamUntilFin(stream); + std::string received_data = + ReadDataFromWebTransportStreamUntilFin(stream, stream_visitor); EXPECT_EQ(received_data, "test"); + + client_->WaitUntil(2000, + [&data_acknowledged]() { return data_acknowledged; }); + EXPECT_TRUE(data_acknowledged); } TEST_P(EndToEndTest, WebTransportSessionBidirectionalStreamWithBuffering) { @@ -6004,11 +6662,37 @@ TEST_P(EndToEndTest, WebTransportDatagrams) { SimpleBufferAllocator allocator; for (int i = 0; i < 10; i++) { - absl::string_view datagram = "test"; - auto buffer = MakeUniqueBuffer(&allocator, datagram.size()); - memcpy(buffer.get(), datagram.data(), datagram.size()); - QuicMemSlice slice(std::move(buffer), datagram.size()); - session->SendOrQueueDatagram(std::move(slice)); + session->SendOrQueueDatagram(MemSliceFromString("test")); + } + + int received = 0; + EXPECT_CALL(visitor, OnDatagramReceived(_)).WillRepeatedly([&received]() { + received++; + }); + client_->WaitUntil(5000, [&received]() { return received > 0; }); + EXPECT_GT(received, 0); +} + +TEST_P(EndToEndTest, WebTransportDatagramsWithContexts) { + enable_web_transport_ = true; + use_datagram_contexts_ = true; + SetPacketLossPercentage(30); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + QuicSpdyStream* connect_stream = nullptr; + WebTransportHttp3* session = CreateWebTransportSession( + "/echo", /*wait_for_server_response=*/true, &connect_stream); + ASSERT_TRUE(session != nullptr); + ASSERT_TRUE(connect_stream != nullptr); + NiceMock& visitor = SetupWebTransportVisitor(session); + + SimpleBufferAllocator allocator; + for (int i = 0; i < 10; i++) { + session->SendOrQueueDatagram(MemSliceFromString("test")); } int received = 0; @@ -6017,6 +6701,311 @@ TEST_P(EndToEndTest, WebTransportDatagrams) { }); client_->WaitUntil(5000, [&received]() { return received > 0; }); EXPECT_GT(received, 0); + EXPECT_TRUE(QuicSpdyStreamPeer::use_datagram_contexts(connect_stream)); +} + +TEST_P(EndToEndTest, WebTransportSessionClose) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = SetupWebTransportVisitor(session); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + EXPECT_TRUE(stream->Write("test")); + // Keep stream open. + + bool close_received = false; + EXPECT_CALL(visitor, OnSessionClosed(42, "test error")) + .WillOnce(Assign(&close_received, true)); + session->CloseSession(42, "test error"); + client_->WaitUntil(2000, [&]() { return close_received; }); + EXPECT_TRUE(close_received); + + QuicSpdyStream* spdy_stream = + GetClientSession()->GetOrCreateSpdyDataStream(stream_id); + EXPECT_TRUE(spdy_stream == nullptr); +} + +TEST_P(EndToEndTest, WebTransportSessionCloseWithoutCapsule) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/echo", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = SetupWebTransportVisitor(session); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + EXPECT_TRUE(stream->Write("test")); + // Keep stream open. + + bool close_received = false; + EXPECT_CALL(visitor, OnSessionClosed(0, "")) + .WillOnce(Assign(&close_received, true)); + session->CloseSessionWithFinOnlyForTests(); + client_->WaitUntil(2000, [&]() { return close_received; }); + EXPECT_TRUE(close_received); + + QuicSpdyStream* spdy_stream = + GetClientSession()->GetOrCreateSpdyDataStream(stream_id); + EXPECT_TRUE(spdy_stream == nullptr); +} + +TEST_P(EndToEndTest, WebTransportSessionReceiveClose) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = CreateWebTransportSession( + "/session-close", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + NiceMock& visitor = SetupWebTransportVisitor(session); + + WebTransportStream* stream = session->OpenOutgoingUnidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + QuicStreamId stream_id = stream->GetStreamId(); + EXPECT_TRUE(stream->Write("42 test error")); + EXPECT_TRUE(stream->SendFin()); + + // Have some other streams open pending, to ensure they are closed properly. + stream = session->OpenOutgoingUnidirectionalStream(); + stream = session->OpenOutgoingBidirectionalStream(); + + bool close_received = false; + EXPECT_CALL(visitor, OnSessionClosed(42, "test error")) + .WillOnce(Assign(&close_received, true)); + client_->WaitUntil(2000, [&]() { return close_received; }); + EXPECT_TRUE(close_received); + + QuicSpdyStream* spdy_stream = + GetClientSession()->GetOrCreateSpdyDataStream(stream_id); + EXPECT_TRUE(spdy_stream == nullptr); +} + +TEST_P(EndToEndTest, WebTransportSessionStreamTermination) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = + CreateWebTransportSession("/resets", /*wait_for_server_response=*/true); + ASSERT_TRUE(session != nullptr); + + NiceMock& visitor = SetupWebTransportVisitor(session); + EXPECT_CALL(visitor, OnIncomingUnidirectionalStreamAvailable()) + .WillRepeatedly([this, session]() { + ReadAllIncomingWebTransportUnidirectionalStreams(session); + }); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + QuicStreamId id1 = stream->GetStreamId(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->Write("test")); + stream->ResetWithUserCode(42); + + // This read fails if the stream is closed in both directions, since that + // results in stream object being deleted. + std::string received_data = ReadDataFromWebTransportStreamUntilFin(stream); + EXPECT_LE(received_data.size(), 4u); + + stream = session->OpenOutgoingBidirectionalStream(); + QuicStreamId id2 = stream->GetStreamId(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->Write("test")); + stream->SendStopSending(24); + + std::array expected_log = { + absl::StrCat("Received reset for stream ", id1, " with error code 42"), + absl::StrCat("Received stop sending for stream ", id2, + " with error code 24"), + }; + client_->WaitUntil(2000, [this, &expected_log]() { + return received_webtransport_unidirectional_streams_.size() >= + expected_log.size(); + }); + EXPECT_THAT(received_webtransport_unidirectional_streams_, + UnorderedElementsAreArray(expected_log)); + + // Since we closed the read side, cleanly closing the write side should result + // in the stream getting deleted. + ASSERT_TRUE(GetClientSession()->GetOrCreateSpdyDataStream(id2) != nullptr); + EXPECT_TRUE(stream->SendFin()); + EXPECT_TRUE(client_->WaitUntil(2000, [this, id2]() { + return GetClientSession()->GetOrCreateSpdyDataStream(id2) == nullptr; + })); +} + +TEST_P(EndToEndTest, WebTransportSession404) { + enable_web_transport_ = true; + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + + WebTransportHttp3* session = CreateWebTransportSession( + "/does-not-exist", /*wait_for_server_response=*/false); + ASSERT_TRUE(session != nullptr); + QuicSpdyStream* connect_stream = client_->latest_created_stream(); + QuicStreamId connect_stream_id = connect_stream->id(); + + WebTransportStream* stream = session->OpenOutgoingBidirectionalStream(); + ASSERT_TRUE(stream != nullptr); + EXPECT_TRUE(stream->Write("test")); + EXPECT_TRUE(stream->SendFin()); + + EXPECT_TRUE(client_->WaitUntil(-1, [this, connect_stream_id]() { + return GetClientSession()->GetOrCreateSpdyDataStream(connect_stream_id) == + nullptr; + })); +} + +TEST_P(EndToEndTest, InvalidExtendedConnect) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + // Missing :path header. + spdy::SpdyHeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "webtransport"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + // An early response should be received. + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, RejectExtendedConnect) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // Disable extended CONNECT. + memory_cache_backend_.set_enable_extended_connect(false); + ASSERT_TRUE(Initialize()); + + if (!version_.UsesHttp3()) { + return; + } + // This extended CONNECT should be rejected. + spdy::SpdyHeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "CONNECT"; + headers[":path"] = "/echo"; + headers[":protocol"] = "webtransport"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); + + // Vanilla CONNECT should be accepted. + spdy::SpdyHeaderBlock headers2; + headers2[":authority"] = "localhost"; + headers2[":method"] = "CONNECT"; + + client_->SendMessage(headers2, "body", /*fin=*/true); + client_->WaitForResponse(); + // No :path header, so 404. + CheckResponseHeaders("404"); +} + +TEST_P(EndToEndTest, RejectInvalidRequestHeader) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + spdy::SpdyHeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "GET"; + headers[":path"] = "/echo"; + // transfer-encoding header is not allowed. + headers["transfer-encoding"] = "chunk"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, RejectTransferEncodingResponse) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + // Add a response with transfer-encoding headers. + SpdyHeaderBlock headers; + headers[":status"] = "200"; + headers["transfer-encoding"] = "gzip"; + + SpdyHeaderBlock trailers; + trailers["some-trailing-header"] = "trailing-header-value"; + + memory_cache_backend_.AddResponse(server_hostname_, "/eep", + std::move(headers), "", trailers.Clone()); + + std::string received_response = client_->SendSynchronousRequest("/eep"); + EXPECT_THAT(client_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(EndToEndTest, RejectUpperCaseRequest) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + spdy::SpdyHeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "GET"; + headers[":path"] = "/echo"; + headers["UpperCaseHeader"] = "foo"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); +} + +TEST_P(EndToEndTest, RejectRequestWithInvalidToken) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + ASSERT_TRUE(Initialize()); + + spdy::SpdyHeaderBlock headers; + headers[":scheme"] = "https"; + headers[":authority"] = "localhost"; + headers[":method"] = "GET"; + headers[":path"] = "/echo"; + headers["invalid,header"] = "foo"; + + client_->SendMessage(headers, "", /*fin=*/false); + client_->WaitForResponse(); + CheckResponseHeaders("400"); } } // namespace diff --git a/gquiche/quic/core/http/http_constants.cc b/gquiche/quic/core/http/http_constants.cc index 4f60153a..f2009127 100644 --- a/gquiche/quic/core/http/http_constants.cc +++ b/gquiche/quic/core/http/http_constants.cc @@ -17,8 +17,10 @@ std::string H3SettingsToString(Http3AndQpackSettingsIdentifiers identifier) { RETURN_STRING_LITERAL(SETTINGS_QPACK_MAX_TABLE_CAPACITY); RETURN_STRING_LITERAL(SETTINGS_MAX_FIELD_SECTION_SIZE); RETURN_STRING_LITERAL(SETTINGS_QPACK_BLOCKED_STREAMS); - RETURN_STRING_LITERAL(SETTINGS_H3_DATAGRAM); + RETURN_STRING_LITERAL(SETTINGS_H3_DATAGRAM_DRAFT00); + RETURN_STRING_LITERAL(SETTINGS_H3_DATAGRAM_DRAFT04); RETURN_STRING_LITERAL(SETTINGS_WEBTRANS_DRAFT00); + RETURN_STRING_LITERAL(SETTINGS_ENABLE_CONNECT_PROTOCOL); } return absl::StrCat("UNSUPPORTED_SETTINGS_TYPE(", identifier, ")"); } diff --git a/gquiche/quic/core/http/http_constants.h b/gquiche/quic/core/http/http_constants.h index 967bc8ab..07701b24 100644 --- a/gquiche/quic/core/http/http_constants.h +++ b/gquiche/quic/core/http/http_constants.h @@ -38,10 +38,14 @@ enum Http3AndQpackSettingsIdentifiers : uint64_t { // Same value as spdy::SETTINGS_MAX_HEADER_LIST_SIZE. SETTINGS_MAX_FIELD_SECTION_SIZE = 0x06, SETTINGS_QPACK_BLOCKED_STREAMS = 0x07, - // draft-ietf-masque-h3-datagram. - SETTINGS_H3_DATAGRAM = 0x276, + // draft-ietf-masque-h3-datagram-00. + SETTINGS_H3_DATAGRAM_DRAFT00 = 0x276, + // draft-ietf-masque-h3-datagram-04. + SETTINGS_H3_DATAGRAM_DRAFT04 = 0xffd277, // draft-ietf-webtrans-http3-00 SETTINGS_WEBTRANS_DRAFT00 = 0x2b603742, + // draft-ietf-httpbis-h3-websockets + SETTINGS_ENABLE_CONNECT_PROTOCOL = 0x08, }; // Returns HTTP/3 SETTINGS identifier as a string. diff --git a/gquiche/quic/core/http/http_decoder.cc b/gquiche/quic/core/http/http_decoder.cc index 01d4abb2..6028d723 100644 --- a/gquiche/quic/core/http/http_decoder.cc +++ b/gquiche/quic/core/http/http_decoder.cc @@ -5,7 +5,6 @@ #include "gquiche/quic/core/http/http_decoder.h" #include -#include #include "absl/base/attributes.h" #include "absl/strings/string_view.h" @@ -21,6 +20,17 @@ namespace quic { +namespace { + +// Limit on the payload length for frames that are buffered by HttpDecoder. +// If a frame header indicating a payload length exceeding this limit is +// received, HttpDecoder closes the connection. Does not apply to frames that +// are not buffered here but each payload fragment is immediately passed to +// Visitor, like HEADERS, DATA, and unknown frames. +constexpr QuicByteCount kPayloadLengthLimit = 1024 * 1024; + +} // anonymous namespace + HttpDecoder::HttpDecoder(Visitor* visitor) : HttpDecoder(visitor, Options()) {} HttpDecoder::HttpDecoder(Visitor* visitor, Options options) : visitor_(visitor), @@ -33,8 +43,6 @@ HttpDecoder::HttpDecoder(Visitor* visitor, Options options) remaining_frame_length_(0), current_type_field_length_(0), remaining_type_field_length_(0), - current_push_id_length_(0), - remaining_push_id_length_(0), error_(QUIC_NO_ERROR), error_detail_("") { QUICHE_DCHECK(visitor_); @@ -92,8 +100,11 @@ QuicByteCount HttpDecoder::ProcessInput(const char* data, QuicByteCount len) { QuicDataReader reader(data, len); bool continue_processing = true; - while (continue_processing && - (reader.BytesRemaining() != 0 || state_ == STATE_FINISH_PARSING)) { + // BufferOrParsePayload() and FinishParsing() may need to be called even if + // there is no more data so that they can finish processing the current frame. + while (continue_processing && (reader.BytesRemaining() != 0 || + state_ == STATE_BUFFER_OR_PARSE_PAYLOAD || + state_ == STATE_FINISH_PARSING)) { // |continue_processing| must have been set to false upon error. QUICHE_DCHECK_EQ(QUIC_NO_ERROR, error_); QUICHE_DCHECK_NE(STATE_ERROR, state_); @@ -105,11 +116,14 @@ QuicByteCount HttpDecoder::ProcessInput(const char* data, QuicByteCount len) { case STATE_READING_FRAME_LENGTH: continue_processing = ReadFrameLength(&reader); break; + case STATE_BUFFER_OR_PARSE_PAYLOAD: + continue_processing = BufferOrParsePayload(&reader); + break; case STATE_READING_FRAME_PAYLOAD: continue_processing = ReadFramePayload(&reader); break; case STATE_FINISH_PARSING: - continue_processing = FinishParsing(&reader); + continue_processing = FinishParsing(); break; case STATE_PARSING_NO_LONGER_POSSIBLE: continue_processing = false; @@ -172,6 +186,18 @@ bool HttpDecoder::ReadFrameType(QuicDataReader* reader) { current_frame_type_)); return false; } + + if (current_frame_type_ == + static_cast(HttpFrameType::CANCEL_PUSH)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "CANCEL_PUSH frame received."); + return false; + } + if (current_frame_type_ == + static_cast(HttpFrameType::PUSH_PROMISE)) { + RaiseError(QUIC_HTTP_FRAME_ERROR, "PUSH_PROMISE frame received."); + return false; + } + state_ = STATE_READING_FRAME_LENGTH; return true; } @@ -217,7 +243,8 @@ bool HttpDecoder::ReadFrameLength(QuicDataReader* reader) { return false; } - if (current_frame_length_ > MaxFrameLength(current_frame_type_)) { + if (IsFrameBuffered() && + current_frame_length_ > MaxFrameLength(current_frame_type_)) { RaiseError(QUIC_HTTP_FRAME_TOO_LARGE, "Frame is too large."); return false; } @@ -238,27 +265,18 @@ bool HttpDecoder::ReadFrameLength(QuicDataReader* reader) { visitor_->OnHeadersFrameStart(header_length, current_frame_length_); break; case static_cast(HttpFrameType::CANCEL_PUSH): + QUICHE_NOTREACHED(); break; case static_cast(HttpFrameType::SETTINGS): continue_processing = visitor_->OnSettingsFrameStart(header_length); break; case static_cast(HttpFrameType::PUSH_PROMISE): - // This edge case needs to be handled here, because ReadFramePayload() - // does not get called if |current_frame_length_| is zero. - if (current_frame_length_ == 0) { - RaiseError(QUIC_HTTP_FRAME_ERROR, - "PUSH_PROMISE frame with empty payload."); - return false; - } - continue_processing = visitor_->OnPushPromiseFrameStart(header_length); + QUICHE_NOTREACHED(); break; case static_cast(HttpFrameType::GOAWAY): break; case static_cast(HttpFrameType::MAX_PUSH_ID): break; - case static_cast(HttpFrameType::PRIORITY_UPDATE): - continue_processing = visitor_->OnPriorityUpdateFrameStart(header_length); - break; case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): continue_processing = visitor_->OnPriorityUpdateFrameStart(header_length); break; @@ -272,12 +290,37 @@ bool HttpDecoder::ReadFrameLength(QuicDataReader* reader) { } remaining_frame_length_ = current_frame_length_; + + if (IsFrameBuffered()) { + state_ = STATE_BUFFER_OR_PARSE_PAYLOAD; + return continue_processing; + } + state_ = (remaining_frame_length_ == 0) ? STATE_FINISH_PARSING : STATE_READING_FRAME_PAYLOAD; return continue_processing; } +bool HttpDecoder::IsFrameBuffered() { + switch (current_frame_type_) { + case static_cast(HttpFrameType::SETTINGS): + return true; + case static_cast(HttpFrameType::GOAWAY): + return true; + case static_cast(HttpFrameType::MAX_PUSH_ID): + return true; + case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): + return true; + case static_cast(HttpFrameType::ACCEPT_CH): + return true; + } + + // Other defined frame types as well as unknown frames are not buffered. + return false; +} + bool HttpDecoder::ReadFramePayload(QuicDataReader* reader) { + QUICHE_DCHECK(!IsFrameBuffered()); QUICHE_DCHECK_NE(0u, reader->BytesRemaining()); QUICHE_DCHECK_NE(0u, remaining_frame_length_); @@ -307,96 +350,31 @@ bool HttpDecoder::ReadFramePayload(QuicDataReader* reader) { break; } case static_cast(HttpFrameType::CANCEL_PUSH): { - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::SETTINGS): { - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::PUSH_PROMISE): { - PushId push_id; - if (current_frame_length_ == remaining_frame_length_) { - // A new Push Promise frame just arrived. - QUICHE_DCHECK_EQ(0u, current_push_id_length_); - current_push_id_length_ = reader->PeekVarInt62Length(); - if (current_push_id_length_ > remaining_frame_length_) { - RaiseError(QUIC_HTTP_FRAME_ERROR, - "Unable to read PUSH_PROMISE push_id."); - return false; - } - if (current_push_id_length_ > reader->BytesRemaining()) { - // Not all bytes of push id is present yet, buffer push id. - QUICHE_DCHECK_EQ(0u, remaining_push_id_length_); - remaining_push_id_length_ = current_push_id_length_; - BufferPushId(reader); - break; - } - bool success = reader->ReadVarInt62(&push_id); - QUICHE_DCHECK(success); - remaining_frame_length_ -= current_push_id_length_; - if (!visitor_->OnPushPromiseFramePushId( - push_id, current_push_id_length_, - current_frame_length_ - current_push_id_length_)) { - continue_processing = false; - current_push_id_length_ = 0; - break; - } - current_push_id_length_ = 0; - } else if (remaining_push_id_length_ > 0) { - // Waiting for more bytes on push id. - BufferPushId(reader); - if (remaining_push_id_length_ != 0) { - break; - } - QuicDataReader push_id_reader(push_id_buffer_.data(), - current_push_id_length_); - - bool success = push_id_reader.ReadVarInt62(&push_id); - QUICHE_DCHECK(success); - if (!visitor_->OnPushPromiseFramePushId( - push_id, current_push_id_length_, - current_frame_length_ - current_push_id_length_)) { - continue_processing = false; - current_push_id_length_ = 0; - break; - } - current_push_id_length_ = 0; - } - - // Read Push Promise headers. - QUICHE_DCHECK_LT(remaining_frame_length_, current_frame_length_); - QuicByteCount bytes_to_read = std::min( - remaining_frame_length_, reader->BytesRemaining()); - if (bytes_to_read == 0) { - break; - } - absl::string_view payload; - bool success = reader->ReadStringPiece(&payload, bytes_to_read); - QUICHE_DCHECK(success); - QUICHE_DCHECK(!payload.empty()); - continue_processing = visitor_->OnPushPromiseFramePayload(payload); - remaining_frame_length_ -= payload.length(); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::GOAWAY): { - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::MAX_PUSH_ID): { - continue_processing = BufferOrParsePayload(reader); - break; - } - case static_cast(HttpFrameType::PRIORITY_UPDATE): { - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): { - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::ACCEPT_CH): { - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } default: { @@ -405,15 +383,15 @@ bool HttpDecoder::ReadFramePayload(QuicDataReader* reader) { } } - // BufferOrParsePayload() may have advanced |state_|. - if (state_ == STATE_READING_FRAME_PAYLOAD && remaining_frame_length_ == 0) { + if (remaining_frame_length_ == 0) { state_ = STATE_FINISH_PARSING; } return continue_processing; } -bool HttpDecoder::FinishParsing(QuicDataReader* reader) { +bool HttpDecoder::FinishParsing() { + QUICHE_DCHECK(!IsFrameBuffered()); QUICHE_DCHECK_EQ(0u, remaining_frame_length_); bool continue_processing = true; @@ -428,59 +406,45 @@ bool HttpDecoder::FinishParsing(QuicDataReader* reader) { break; } case static_cast(HttpFrameType::CANCEL_PUSH): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::SETTINGS): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::PUSH_PROMISE): { - continue_processing = visitor_->OnPushPromiseFrameEnd(); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::GOAWAY): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::MAX_PUSH_ID): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); - break; - } - case static_cast(HttpFrameType::PRIORITY_UPDATE): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } case static_cast(HttpFrameType::ACCEPT_CH): { - // If frame payload is not empty, FinishParsing() is skipped. - QUICHE_DCHECK_EQ(0u, current_frame_length_); - continue_processing = BufferOrParsePayload(reader); + QUICHE_NOTREACHED(); break; } default: continue_processing = visitor_->OnUnknownFrameEnd(); } + ResetForNextFrame(); + return continue_processing; +} + +void HttpDecoder::ResetForNextFrame() { current_length_field_length_ = 0; current_type_field_length_ = 0; state_ = STATE_READING_FRAME_TYPE; - return continue_processing; } bool HttpDecoder::HandleUnknownFramePayload(QuicDataReader* reader) { @@ -494,83 +458,56 @@ bool HttpDecoder::HandleUnknownFramePayload(QuicDataReader* reader) { return visitor_->OnUnknownFramePayload(payload); } -void HttpDecoder::DiscardFramePayload(QuicDataReader* reader) { - QuicByteCount bytes_to_read = std::min( - remaining_frame_length_, reader->BytesRemaining()); - absl::string_view payload; - bool success = reader->ReadStringPiece(&payload, bytes_to_read); - QUICHE_DCHECK(success); - remaining_frame_length_ -= payload.length(); - if (remaining_frame_length_ == 0) { - state_ = STATE_READING_FRAME_TYPE; - current_length_field_length_ = 0; - current_type_field_length_ = 0; - } -} - bool HttpDecoder::BufferOrParsePayload(QuicDataReader* reader) { + QUICHE_DCHECK(IsFrameBuffered()); QUICHE_DCHECK_EQ(current_frame_length_, buffer_.size() + remaining_frame_length_); - bool continue_processing = true; - if (buffer_.empty() && reader->BytesRemaining() >= current_frame_length_) { // |*reader| contains entire payload, which might be empty. remaining_frame_length_ = 0; QuicDataReader current_payload_reader(reader->PeekRemainingPayload().data(), current_frame_length_); - continue_processing = ParseEntirePayload(¤t_payload_reader); - reader->Seek(current_frame_length_); - } else { - if (buffer_.empty()) { - buffer_.reserve(current_frame_length_); - } + bool continue_processing = ParseEntirePayload(¤t_payload_reader); - // Buffer as much of the payload as |*reader| contains. - QuicByteCount bytes_to_read = std::min( - remaining_frame_length_, reader->BytesRemaining()); - absl::StrAppend(&buffer_, reader->PeekRemainingPayload().substr( - /* pos = */ 0, bytes_to_read)); - reader->Seek(bytes_to_read); - remaining_frame_length_ -= bytes_to_read; + reader->Seek(current_frame_length_); + ResetForNextFrame(); + return continue_processing; + } - QUICHE_DCHECK_EQ(current_frame_length_, - buffer_.size() + remaining_frame_length_); + // Buffer as much of the payload as |*reader| contains. + QuicByteCount bytes_to_read = std::min( + remaining_frame_length_, reader->BytesRemaining()); + absl::StrAppend(&buffer_, reader->PeekRemainingPayload().substr( + /* pos = */ 0, bytes_to_read)); + reader->Seek(bytes_to_read); + remaining_frame_length_ -= bytes_to_read; - if (remaining_frame_length_ > 0) { - QUICHE_DCHECK(reader->IsDoneReading()); - return true; - } + QUICHE_DCHECK_EQ(current_frame_length_, + buffer_.size() + remaining_frame_length_); - QuicDataReader buffer_reader(buffer_); - continue_processing = ParseEntirePayload(&buffer_reader); - buffer_.clear(); + if (remaining_frame_length_ > 0) { + QUICHE_DCHECK(reader->IsDoneReading()); + return false; } - current_length_field_length_ = 0; - current_type_field_length_ = 0; - state_ = STATE_READING_FRAME_TYPE; + QuicDataReader buffer_reader(buffer_); + bool continue_processing = ParseEntirePayload(&buffer_reader); + buffer_.clear(); + + ResetForNextFrame(); return continue_processing; } bool HttpDecoder::ParseEntirePayload(QuicDataReader* reader) { + QUICHE_DCHECK(IsFrameBuffered()); QUICHE_DCHECK_EQ(current_frame_length_, reader->BytesRemaining()); QUICHE_DCHECK_EQ(0u, remaining_frame_length_); switch (current_frame_type_) { case static_cast(HttpFrameType::CANCEL_PUSH): { - CancelPushFrame frame; - if (!reader->ReadVarInt62(&frame.push_id)) { - RaiseError(QUIC_HTTP_FRAME_ERROR, - "Unable to read CANCEL_PUSH push_id."); - return false; - } - if (!reader->IsDoneReading()) { - RaiseError(QUIC_HTTP_FRAME_ERROR, - "Superfluous data in CANCEL_PUSH frame."); - return false; - } - return visitor_->OnCancelPushFrame(frame); + QUICHE_NOTREACHED(); + return false; } case static_cast(HttpFrameType::SETTINGS): { SettingsFrame frame; @@ -605,16 +542,9 @@ bool HttpDecoder::ParseEntirePayload(QuicDataReader* reader) { } return visitor_->OnMaxPushIdFrame(frame); } - case static_cast(HttpFrameType::PRIORITY_UPDATE): { - PriorityUpdateFrame frame; - if (!ParsePriorityUpdateFrame(reader, &frame)) { - return false; - } - return visitor_->OnPriorityUpdateFrame(frame); - } case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): { PriorityUpdateFrame frame; - if (!ParseNewPriorityUpdateFrame(reader, &frame)) { + if (!ParsePriorityUpdateFrame(reader, &frame)) { return false; } return visitor_->OnPriorityUpdateFrame(frame); @@ -655,19 +585,6 @@ void HttpDecoder::BufferFrameType(QuicDataReader* reader) { remaining_type_field_length_ -= bytes_to_read; } -void HttpDecoder::BufferPushId(QuicDataReader* reader) { - QUICHE_DCHECK_LE(remaining_push_id_length_, current_frame_length_); - QuicByteCount bytes_to_read = std::min( - reader->BytesRemaining(), remaining_push_id_length_); - bool success = - reader->ReadBytes(push_id_buffer_.data() + current_push_id_length_ - - remaining_push_id_length_, - bytes_to_read); - QUICHE_DCHECK(success); - remaining_push_id_length_ -= bytes_to_read; - remaining_frame_length_ -= bytes_to_read; -} - void HttpDecoder::RaiseError(QuicErrorCode error, std::string error_detail) { state_ = STATE_ERROR; error_ = error; @@ -699,36 +616,6 @@ bool HttpDecoder::ParseSettingsFrame(QuicDataReader* reader, } bool HttpDecoder::ParsePriorityUpdateFrame(QuicDataReader* reader, - PriorityUpdateFrame* frame) { - uint8_t prioritized_element_type; - if (!reader->ReadUInt8(&prioritized_element_type)) { - RaiseError(QUIC_HTTP_FRAME_ERROR, - "Unable to read prioritized element type."); - return false; - } - - if (prioritized_element_type != REQUEST_STREAM && - prioritized_element_type != PUSH_STREAM) { - RaiseError(QUIC_HTTP_FRAME_ERROR, "Invalid prioritized element type."); - return false; - } - - frame->prioritized_element_type = - static_cast(prioritized_element_type); - - if (!reader->ReadVarInt62(&frame->prioritized_element_id)) { - RaiseError(QUIC_HTTP_FRAME_ERROR, "Unable to read prioritized element id."); - return false; - } - - absl::string_view priority_field_value = reader->ReadRemainingPayload(); - frame->priority_field_value = - std::string(priority_field_value.data(), priority_field_value.size()); - - return true; -} - -bool HttpDecoder::ParseNewPriorityUpdateFrame(QuicDataReader* reader, PriorityUpdateFrame* frame) { frame->prioritized_element_type = REQUEST_STREAM; @@ -765,28 +652,22 @@ bool HttpDecoder::ParseAcceptChFrame(QuicDataReader* reader, } QuicByteCount HttpDecoder::MaxFrameLength(uint64_t frame_type) { + QUICHE_DCHECK(IsFrameBuffered()); + switch (frame_type) { - case static_cast(HttpFrameType::CANCEL_PUSH): - return sizeof(PushId); case static_cast(HttpFrameType::SETTINGS): - // This limit is arbitrary. - return 1024 * 1024; + return kPayloadLengthLimit; case static_cast(HttpFrameType::GOAWAY): return VARIABLE_LENGTH_INTEGER_LENGTH_8; case static_cast(HttpFrameType::MAX_PUSH_ID): - return sizeof(PushId); - case static_cast(HttpFrameType::PRIORITY_UPDATE): - // This limit is arbitrary. - return 1024 * 1024; + return VARIABLE_LENGTH_INTEGER_LENGTH_8; case static_cast(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM): - // This limit is arbitrary. - return 1024 * 1024; + return kPayloadLengthLimit; case static_cast(HttpFrameType::ACCEPT_CH): - // This limit is arbitrary. - return 1024 * 1024; + return kPayloadLengthLimit; default: - // Other frames require no data buffering, so it's safe to have no limit. - return std::numeric_limits::max(); + QUICHE_NOTREACHED(); + return 0; } } diff --git a/gquiche/quic/core/http/http_decoder.h b/gquiche/quic/core/http/http_decoder.h index 6dc288ea..861a847d 100644 --- a/gquiche/quic/core/http/http_decoder.h +++ b/gquiche/quic/core/http/http_decoder.h @@ -44,9 +44,6 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { // On*FrameStart() methods are called after the frame header is completely // processed. At that point it is safe to consume |header_length| bytes. - // Called when a CANCEL_PUSH frame has been successfully parsed. - virtual bool OnCancelPushFrame(const CancelPushFrame& frame) = 0; - // Called when a MAX_PUSH_ID frame has been successfully parsed. virtual bool OnMaxPushIdFrame(const MaxPushIdFrame& frame) = 0; @@ -83,23 +80,6 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { // Called when a HEADERS frame has been completely processed. virtual bool OnHeadersFrameEnd() = 0; - // Called when a PUSH_PROMISE frame has been received. - virtual bool OnPushPromiseFrameStart(QuicByteCount header_length) = 0; - // Called when the Push ID field of a PUSH_PROMISE frame has been parsed. - // Called exactly once for a valid PUSH_PROMISE frame. - // |push_id_length| is the length of the push ID field. - // |header_block_length| is the length of the compressed header block. - virtual bool OnPushPromiseFramePushId( - PushId push_id, - QuicByteCount push_id_length, - QuicByteCount header_block_length) = 0; - // Called when part of the header block of a PUSH_PROMISE frame has been - // read. May be called multiple times for a single frame. |payload| is - // guaranteed to be non-empty. - virtual bool OnPushPromiseFramePayload(absl::string_view payload) = 0; - // Called when a PUSH_PROMISE frame has been completely processed. - virtual bool OnPushPromiseFrameEnd() = 0; - // Called when a PRIORITY_UPDATE frame has been received. // |header_length| contains PRIORITY_UPDATE frame length and payload length. virtual bool OnPriorityUpdateFrameStart(QuicByteCount header_length) = 0; @@ -175,8 +155,14 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { enum HttpDecoderState { STATE_READING_FRAME_LENGTH, STATE_READING_FRAME_TYPE, + + // States used for buffered frame types + STATE_BUFFER_OR_PARSE_PAYLOAD, + + // States used for non-buffered frame types STATE_READING_FRAME_PAYLOAD, STATE_FINISH_PARSING, + STATE_PARSING_NO_LONGER_POSSIBLE, STATE_ERROR }; @@ -191,35 +177,44 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { // if there are any errors. Returns whether processing should continue. bool ReadFrameLength(QuicDataReader* reader); - // Depending on the frame type, reads and processes the payload of the current - // frame from |reader| and calls visitor methods, or calls - // BufferOrParsePayload(). Returns whether processing should continue. + // Returns whether the current frame is of a buffered type. + // The payload of buffered frames is buffered by HttpDecoder, and parsed by + // HttpDecoder after the entire frame has been received. (Copying to the + // buffer is skipped if the ProcessInput() call covers the entire payload.) + // Frames that are not buffered have every payload fragment synchronously + // passed to the Visitor without buffering. + bool IsFrameBuffered(); + + // For buffered frame types, calls BufferOrParsePayload(). For other frame + // types, reads the payload of the current frame from |reader| and calls + // visitor methods. Returns whether processing should continue. bool ReadFramePayload(QuicDataReader* reader); - // For frame types parsed by BufferOrParsePayload(), this method is only - // called if frame payload is empty, at it calls BufferOrParsePayload(). For - // other frame types, this method directly calls visitor methods to signal - // that frame had been parsed completely. Returns whether processing should - // continue. - bool FinishParsing(QuicDataReader* reader); + // For buffered frame types, this method is only called if frame payload is + // empty, and it calls BufferOrParsePayload(). For other frame types, this + // method directly calls visitor methods to signal that frame had been + // received completely. Returns whether processing should continue. + bool FinishParsing(); + + // Reset internal fields to prepare for reading next frame. + void ResetForNextFrame(); // Read payload of unknown frame from |reader| and call // Visitor::OnUnknownFramePayload(). Returns true decoding should continue, // false if it should be paused. bool HandleUnknownFramePayload(QuicDataReader* reader); - // Discards any remaining frame payload from |reader|. - void DiscardFramePayload(QuicDataReader* reader); - // Buffers any remaining frame payload from |*reader| into |buffer_| if // necessary. Parses the frame payload if complete. Parses out of |*reader| - // without unnecessary copy if |*reader| has entire payload. + // without unnecessary copy if |*reader| contains entire payload. // Returns whether processing should continue. + // Must only be called when current frame type is buffered. bool BufferOrParsePayload(QuicDataReader* reader); // Parses the entire payload of certain kinds of frames that are parsed in a // single pass. |reader| must have at least |current_frame_length_| bytes. // Returns whether processing should continue. + // Must only be called when current frame type is buffered. bool ParseEntirePayload(QuicDataReader* reader); // Buffers any remaining frame length field from |reader| into @@ -229,26 +224,17 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { // Buffers any remaining frame type field from |reader| into |type_buffer_|. void BufferFrameType(QuicDataReader* reader); - // Buffers at most |remaining_push_id_length_| from |reader| to - // |push_id_buffer_|. - void BufferPushId(QuicDataReader* reader); - // Sets |error_| and |error_detail_| accordingly. void RaiseError(QuicErrorCode error, std::string error_detail); // Parses the payload of a SETTINGS frame from |reader| into |frame|. bool ParseSettingsFrame(QuicDataReader* reader, SettingsFrame* frame); - // Parses the payload of a PRIORITY_UPDATE frame (draft-01, type 0x0f) + // Parses the payload of a PRIORITY_UPDATE frame (draft-02, type 0xf0700) // from |reader| into |frame|. bool ParsePriorityUpdateFrame(QuicDataReader* reader, PriorityUpdateFrame* frame); - // Parses the payload of a PRIORITY_UPDATE frame (draft-02, type 0xf0700) - // from |reader| into |frame|. - bool ParseNewPriorityUpdateFrame(QuicDataReader* reader, - PriorityUpdateFrame* frame); - // Parses the payload of an ACCEPT_CH frame from |reader| into |frame|. bool ParseAcceptChFrame(QuicDataReader* reader, AcceptChFrame* frame); @@ -275,10 +261,6 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { QuicByteCount current_type_field_length_; // Remaining length that's needed for the frame's type field. QuicByteCount remaining_type_field_length_; - // Length of PUSH_PROMISE frame's push id. - QuicByteCount current_push_id_length_; - // Remaining length that's needed for PUSH_PROMISE frame's push id field. - QuicByteCount remaining_push_id_length_; // Last error. QuicErrorCode error_; // The issue which caused |error_| @@ -289,8 +271,6 @@ class QUIC_EXPORT_PRIVATE HttpDecoder { std::array length_buffer_; // Remaining unparsed type field data. std::array type_buffer_; - // Remaining unparsed push id data. - std::array push_id_buffer_; }; } // namespace quic diff --git a/gquiche/quic/core/http/http_decoder_test.cc b/gquiche/quic/core/http/http_decoder_test.cc index a7be107a..69672ab4 100644 --- a/gquiche/quic/core/http/http_decoder_test.cc +++ b/gquiche/quic/core/http/http_decoder_test.cc @@ -19,7 +19,6 @@ #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using ::testing::_; using ::testing::AnyNumber; @@ -28,7 +27,6 @@ using ::testing::InSequence; using ::testing::Return; namespace quic { - namespace test { class HttpDecoderPeer { @@ -38,17 +36,15 @@ class HttpDecoderPeer { } }; -class MockVisitor : public HttpDecoder::Visitor { +namespace { + +class MockHttpDecoderVisitor : public HttpDecoder::Visitor { public: - ~MockVisitor() override = default; + ~MockHttpDecoderVisitor() override = default; // Called if an error is detected. MOCK_METHOD(void, OnError, (HttpDecoder*), (override)); - MOCK_METHOD(bool, - OnCancelPushFrame, - (const CancelPushFrame& frame), - (override)); MOCK_METHOD(bool, OnMaxPushIdFrame, (const MaxPushIdFrame& frame), @@ -80,22 +76,6 @@ class MockVisitor : public HttpDecoder::Visitor { (override)); MOCK_METHOD(bool, OnHeadersFrameEnd, (), (override)); - MOCK_METHOD(bool, - OnPushPromiseFrameStart, - (QuicByteCount header_length), - (override)); - MOCK_METHOD(bool, - OnPushPromiseFramePushId, - (PushId push_id, - QuicByteCount push_id_length, - QuicByteCount header_block_length), - (override)); - MOCK_METHOD(bool, - OnPushPromiseFramePayload, - (absl::string_view payload), - (override)); - MOCK_METHOD(bool, OnPushPromiseFrameEnd, (), (override)); - MOCK_METHOD(bool, OnPriorityUpdateFrameStart, (QuicByteCount header_length), @@ -131,7 +111,6 @@ class MockVisitor : public HttpDecoder::Visitor { class HttpDecoderTest : public QuicTest { public: HttpDecoderTest() : decoder_(&visitor_) { - ON_CALL(visitor_, OnCancelPushFrame(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnMaxPushIdFrame(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnGoAwayFrame(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnSettingsFrameStart(_)).WillByDefault(Return(true)); @@ -142,11 +121,6 @@ class HttpDecoderTest : public QuicTest { ON_CALL(visitor_, OnHeadersFrameStart(_, _)).WillByDefault(Return(true)); ON_CALL(visitor_, OnHeadersFramePayload(_)).WillByDefault(Return(true)); ON_CALL(visitor_, OnHeadersFrameEnd()).WillByDefault(Return(true)); - ON_CALL(visitor_, OnPushPromiseFrameStart(_)).WillByDefault(Return(true)); - ON_CALL(visitor_, OnPushPromiseFramePushId(_, _, _)) - .WillByDefault(Return(true)); - ON_CALL(visitor_, OnPushPromiseFramePayload(_)).WillByDefault(Return(true)); - ON_CALL(visitor_, OnPushPromiseFrameEnd()).WillByDefault(Return(true)); ON_CALL(visitor_, OnPriorityUpdateFrameStart(_)) .WillByDefault(Return(true)); ON_CALL(visitor_, OnPriorityUpdateFrame(_)).WillByDefault(Return(true)); @@ -191,7 +165,7 @@ class HttpDecoderTest : public QuicTest { return processed_bytes; } - testing::StrictMock visitor_; + testing::StrictMock visitor_; HttpDecoder decoder_; }; @@ -249,124 +223,24 @@ TEST_F(HttpDecoderTest, CancelPush) { "01" // length "01"); // Push Id - // Visitor pauses processing. - EXPECT_CALL(visitor_, OnCancelPushFrame(CancelPushFrame({1}))) - .WillOnce(Return(false)); - EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the full frame. - EXPECT_CALL(visitor_, OnCancelPushFrame(CancelPushFrame({1}))); - EXPECT_EQ(input.size(), ProcessInput(input)); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the frame incrementally. - EXPECT_CALL(visitor_, OnCancelPushFrame(CancelPushFrame({1}))); - ProcessInputCharByChar(input); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(1u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ("CANCEL_PUSH frame received.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, PushPromiseFrame) { InSequence s; std::string input = - absl::StrCat(absl::HexStringToBytes("05" // type (PUSH PROMISE) - "0f" // length - "C000000000000101"), // push id 257 - "Headers"); // headers - - // Visitor pauses processing. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)).WillOnce(Return(false)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8, 7)) - .WillOnce(Return(false)); - absl::string_view remaining_input(input); - QuicByteCount processed_bytes = - ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(2u, processed_bytes); - remaining_input = remaining_input.substr(processed_bytes); - processed_bytes = ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(8u, processed_bytes); - remaining_input = remaining_input.substr(processed_bytes); - - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("Headers"))) - .WillOnce(Return(false)); - processed_bytes = ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(remaining_input.size(), processed_bytes); - - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()).WillOnce(Return(false)); - EXPECT_EQ(0u, ProcessInputWithGarbageAppended("")); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the full frame. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8, 7)); - EXPECT_CALL(visitor_, - OnPushPromiseFramePayload(absl::string_view("Headers"))); - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); - EXPECT_EQ(input.size(), ProcessInput(input)); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the frame incrementally. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8, 7)); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("H"))); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("e"))); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("a"))); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("d"))); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("e"))); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("r"))); - EXPECT_CALL(visitor_, OnPushPromiseFramePayload(absl::string_view("s"))); - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); - ProcessInputCharByChar(input); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process push id incrementally and append headers with last byte of push id. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(257, 8, 7)); - EXPECT_CALL(visitor_, - OnPushPromiseFramePayload(absl::string_view("Headers"))); - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); - ProcessInputCharByChar(input.substr(0, 9)); - EXPECT_EQ(8u, ProcessInput(input.substr(9))); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); -} - -TEST_F(HttpDecoderTest, CorruptPushPromiseFrame) { - InSequence s; - - std::string input = absl::HexStringToBytes( - "05" // type (PUSH_PROMISE) - "01" // length - "40"); // first byte of two-byte varint push id - - { - HttpDecoder decoder(&visitor_); - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnError(&decoder)); - - decoder.ProcessInput(input.data(), input.size()); - - EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); - EXPECT_EQ("Unable to read PUSH_PROMISE push_id.", decoder.error_detail()); - } - { - HttpDecoder decoder(&visitor_); - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnError(&decoder)); - - for (auto c : input) { - decoder.ProcessInput(&c, 1); - } + absl::StrCat(absl::HexStringToBytes("05" // type (PUSH PROMISE) + "08" // length + "1f"), // push id 31 + "Headers"); // headers - EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); - EXPECT_EQ("Unable to read PUSH_PROMISE push_id.", decoder.error_detail()); - } + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(1u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); + EXPECT_EQ("PUSH_PROMISE frame received.", decoder_.error_detail()); } TEST_F(HttpDecoderTest, MaxPushId) { @@ -545,10 +419,8 @@ TEST_F(HttpDecoderTest, FrameHeaderPartialDelivery) { InSequence s; // A large input that will occupy more than 1 byte in the length field. std::string input(2048, 'x'); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(input.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + input.length(), SimpleBufferAllocator::Get()); // Partially send only 1 byte of the header to process. EXPECT_EQ(1u, decoder_.ProcessInput(header.data(), 1)); EXPECT_THAT(decoder_.error(), IsQuicNoError()); @@ -556,8 +428,8 @@ TEST_F(HttpDecoderTest, FrameHeaderPartialDelivery) { // Send the rest of the header. EXPECT_CALL(visitor_, OnDataFrameStart(3, input.length())); - EXPECT_EQ(header_length - 1, - decoder_.ProcessInput(header.data() + 1, header_length - 1)); + EXPECT_EQ(header.size() - 1, + decoder_.ProcessInput(header.data() + 1, header.size() - 1)); EXPECT_THAT(decoder_.error(), IsQuicNoError()); EXPECT_EQ("", decoder_.error_detail()); @@ -732,53 +604,48 @@ TEST_F(HttpDecoderTest, EmptyHeadersFrame) { EXPECT_EQ("", decoder_.error_detail()); } -TEST_F(HttpDecoderTest, PushPromiseFrameNoHeaders) { - InSequence s; +TEST_F(HttpDecoderTest, GoawayWithOverlyLargePayload) { std::string input = absl::HexStringToBytes( - "05" // type (PUSH_PROMISE) - "01" // length - "01"); // Push Id - - // Visitor pauses processing. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1, 0)) - .WillOnce(Return(false)); - EXPECT_EQ(input.size(), ProcessInputWithGarbageAppended(input)); - - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()).WillOnce(Return(false)); - EXPECT_EQ(0u, ProcessInputWithGarbageAppended("")); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the full frame. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1, 0)); - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); - EXPECT_EQ(input.size(), ProcessInput(input)); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the frame incrementally. - EXPECT_CALL(visitor_, OnPushPromiseFrameStart(2)); - EXPECT_CALL(visitor_, OnPushPromiseFramePushId(1, 1, 0)); - EXPECT_CALL(visitor_, OnPushPromiseFrameEnd()); - ProcessInputCharByChar(input); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); + "07" // type (GOAWAY) + "10"); // length exceeding the maximum possible length for GOAWAY frame + // Process all data at once. + EXPECT_CALL(visitor_, OnError(&decoder_)); + EXPECT_EQ(2u, ProcessInput(input)); + EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_TOO_LARGE)); + EXPECT_EQ("Frame is too large.", decoder_.error_detail()); } -TEST_F(HttpDecoderTest, MalformedFrameWithOverlyLargePayload) { +TEST_F(HttpDecoderTest, MaxPushIdWithOverlyLargePayload) { std::string input = absl::HexStringToBytes( - "03" // type (CANCEL_PUSH) - "10" // length - "15"); // malformed payload - // Process the full frame. + "0d" // type (MAX_PUSH_ID) + "10"); // length exceeding the maximum possible length for MAX_PUSH_ID + // frame + // Process all data at once. EXPECT_CALL(visitor_, OnError(&decoder_)); EXPECT_EQ(2u, ProcessInput(input)); EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_TOO_LARGE)); EXPECT_EQ("Frame is too large.", decoder_.error_detail()); } +TEST_F(HttpDecoderTest, FrameWithOverlyLargePayload) { + // Regression test for b/193919867: Ensure that reading frames with incredibly + // large payload lengths does not lead to allocating unbounded memory. + constexpr size_t max_input_length = + /*max frame type varint length*/ sizeof(uint64_t) + + /*max frame length varint length*/ sizeof(uint64_t) + + /*one byte of payload*/ sizeof(uint8_t); + char input[max_input_length]; + for (uint64_t frame_type = 0; frame_type < 1025; frame_type++) { + ::testing::NiceMock visitor; + HttpDecoder decoder(&visitor); + QuicDataWriter writer(max_input_length, input); + ASSERT_TRUE(writer.WriteVarInt62(frame_type)); // frame type. + ASSERT_TRUE(writer.WriteVarInt62(kVarInt62MaxValue)); // frame length. + ASSERT_TRUE(writer.WriteUInt8(0x00)); // one byte of payload. + EXPECT_NE(decoder.ProcessInput(input, writer.length()), 0u) << frame_type; + } +} + TEST_F(HttpDecoderTest, MalformedSettingsFrame) { char input[30]; QuicDataWriter writer(30, input); @@ -846,16 +713,7 @@ TEST_F(HttpDecoderTest, CorruptFrame) { struct { const char* const input; const char* const error_message; - } kTestData[] = {{"\x03" // type (CANCEL_PUSH) - "\x01" // length - "\x40", // first byte of two-byte varint push id - "Unable to read CANCEL_PUSH push_id."}, - {"\x03" // type (CANCEL_PUSH) - "\x04" // length - "\x05" // valid push id - "foo", // superfluous data - "Superfluous data in CANCEL_PUSH frame."}, - {"\x0D" // type (MAX_PUSH_ID) + } kTestData[] = {{"\x0D" // type (MAX_PUSH_ID) "\x01" // length "\x40", // first byte of two-byte varint push id "Unable to read MAX_PUSH_ID push_id."}, @@ -931,17 +789,6 @@ TEST_F(HttpDecoderTest, CorruptFrame) { } } -TEST_F(HttpDecoderTest, EmptyCancelPushFrame) { - std::string input = absl::HexStringToBytes( - "03" // type (CANCEL_PUSH) - "00"); // frame length - - EXPECT_CALL(visitor_, OnError(&decoder_)); - EXPECT_EQ(input.size(), ProcessInput(input)); - EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); - EXPECT_EQ("Unable to read CANCEL_PUSH push_id.", decoder_.error_detail()); -} - TEST_F(HttpDecoderTest, EmptySettingsFrame) { std::string input = absl::HexStringToBytes( "04" // type (SETTINGS) @@ -957,18 +804,6 @@ TEST_F(HttpDecoderTest, EmptySettingsFrame) { EXPECT_EQ("", decoder_.error_detail()); } -// Regression test for https://crbug.com/1001823. -TEST_F(HttpDecoderTest, EmptyPushPromiseFrame) { - std::string input = absl::HexStringToBytes( - "05" // type (PUSH_PROMISE) - "00"); // frame length - - EXPECT_CALL(visitor_, OnError(&decoder_)); - EXPECT_EQ(input.size(), ProcessInput(input)); - EXPECT_THAT(decoder_.error(), IsError(QUIC_HTTP_FRAME_ERROR)); - EXPECT_EQ("PUSH_PROMISE frame with empty payload.", decoder_.error_detail()); -} - TEST_F(HttpDecoderTest, EmptyGoAwayFrame) { std::string input = absl::HexStringToBytes( "07" // type (GOAWAY) @@ -1003,89 +838,41 @@ TEST_F(HttpDecoderTest, LargeStreamIdInGoAway) { EXPECT_EQ("", decoder_.error_detail()); } -TEST_F(HttpDecoderTest, PriorityUpdateFrame) { +// Old PRIORITY_UPDATE frame is parsed as unknown frame. +TEST_F(HttpDecoderTest, ObsoletePriorityUpdateFrame) { + const QuicByteCount header_length = 2; + const QuicByteCount payload_length = 3; InSequence s; - std::string input1 = absl::HexStringToBytes( - "0f" // type (PRIORITY_UPDATE) - "02" // length - "00" // prioritized element type: REQUEST_STREAM - "03"); // prioritized element id - - PriorityUpdateFrame priority_update1; - priority_update1.prioritized_element_type = REQUEST_STREAM; - priority_update1.prioritized_element_id = 0x03; - - // Visitor pauses processing. - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(2)).WillOnce(Return(false)); - absl::string_view remaining_input(input1); - QuicByteCount processed_bytes = - ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(2u, processed_bytes); - remaining_input = remaining_input.substr(processed_bytes); - - EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update1)) - .WillOnce(Return(false)); - processed_bytes = ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(remaining_input.size(), processed_bytes); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the full frame. - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(2)); - EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update1)); - EXPECT_EQ(input1.size(), ProcessInput(input1)); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - // Process the frame incrementally. - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(2)); - EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update1)); - ProcessInputCharByChar(input1); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); - - std::string input2 = absl::HexStringToBytes( - "0f" // type (PRIORITY_UPDATE) - "05" // length - "80" // prioritized element type: PUSH_STREAM - "05" // prioritized element id - "666f6f"); // priority field value: "foo" - - PriorityUpdateFrame priority_update2; - priority_update2.prioritized_element_type = PUSH_STREAM; - priority_update2.prioritized_element_id = 0x05; - priority_update2.priority_field_value = "foo"; + std::string input = absl::HexStringToBytes( + "0f" // type (obsolete PRIORITY_UPDATE) + "03" // length + "666f6f"); // payload "foo" - // Visitor pauses processing. - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(2)).WillOnce(Return(false)); - remaining_input = input2; - processed_bytes = ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(2u, processed_bytes); - remaining_input = remaining_input.substr(processed_bytes); + // Process frame as a whole. + EXPECT_CALL(visitor_, + OnUnknownFrameStart(0x0f, header_length, payload_length)); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("foo"))); + EXPECT_CALL(visitor_, OnUnknownFrameEnd()).WillOnce(Return(false)); - EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update2)) - .WillOnce(Return(false)); - processed_bytes = ProcessInputWithGarbageAppended(remaining_input); - EXPECT_EQ(remaining_input.size(), processed_bytes); + EXPECT_EQ(header_length + payload_length, + ProcessInputWithGarbageAppended(input)); EXPECT_THAT(decoder_.error(), IsQuicNoError()); EXPECT_EQ("", decoder_.error_detail()); - // Process the full frame. - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(2)); - EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update2)); - EXPECT_EQ(input2.size(), ProcessInput(input2)); - EXPECT_THAT(decoder_.error(), IsQuicNoError()); - EXPECT_EQ("", decoder_.error_detail()); + // Process frame byte by byte. + EXPECT_CALL(visitor_, + OnUnknownFrameStart(0x0f, header_length, payload_length)); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("f"))); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("o"))); + EXPECT_CALL(visitor_, OnUnknownFramePayload(Eq("o"))); + EXPECT_CALL(visitor_, OnUnknownFrameEnd()); - // Process the frame incrementally. - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(2)); - EXPECT_CALL(visitor_, OnPriorityUpdateFrame(priority_update2)); - ProcessInputCharByChar(input2); + ProcessInputCharByChar(input); EXPECT_THAT(decoder_.error(), IsQuicNoError()); EXPECT_EQ("", decoder_.error_detail()); } -TEST_F(HttpDecoderTest, NewPriorityUpdateFrame) { +TEST_F(HttpDecoderTest, PriorityUpdateFrame) { InSequence s; std::string input1 = absl::HexStringToBytes( "800f0700" // type (PRIORITY_UPDATE) @@ -1166,42 +953,6 @@ TEST_F(HttpDecoderTest, NewPriorityUpdateFrame) { } TEST_F(HttpDecoderTest, CorruptPriorityUpdateFrame) { - std::string payload1 = absl::HexStringToBytes( - "80" // prioritized element type: PUSH_STREAM - "4005"); // prioritized element id - std::string payload2 = - absl::HexStringToBytes("42"); // invalid prioritized element type - struct { - const char* const payload; - size_t payload_length; - const char* const error_message; - } kTestData[] = { - {payload1.data(), 0, "Unable to read prioritized element type."}, - {payload1.data(), 1, "Unable to read prioritized element id."}, - {payload1.data(), 2, "Unable to read prioritized element id."}, - {payload2.data(), 1, "Invalid prioritized element type."}, - }; - - for (const auto& test_data : kTestData) { - std::string input; - input.push_back(15u); // type PRIORITY_UPDATE - input.push_back(test_data.payload_length); - size_t header_length = input.size(); - input.append(test_data.payload, test_data.payload_length); - - HttpDecoder decoder(&visitor_); - EXPECT_CALL(visitor_, OnPriorityUpdateFrameStart(header_length)); - EXPECT_CALL(visitor_, OnError(&decoder)); - - QuicByteCount processed_bytes = - decoder.ProcessInput(input.data(), input.size()); - EXPECT_EQ(input.size(), processed_bytes); - EXPECT_THAT(decoder.error(), IsError(QUIC_HTTP_FRAME_ERROR)); - EXPECT_EQ(test_data.error_message, decoder.error_detail()); - } -} - -TEST_F(HttpDecoderTest, CorruptNewPriorityUpdateFrame) { std::string payload = absl::HexStringToBytes("4005"); // prioritized element id struct { @@ -1318,7 +1069,7 @@ TEST_F(HttpDecoderTest, WebTransportStreamDisabled) { TEST(HttpDecoderTestNoFixture, WebTransportStream) { HttpDecoder::Options options; options.allow_web_transport_stream = true; - testing::StrictMock visitor; + testing::StrictMock visitor; HttpDecoder decoder(&visitor, options); // WebTransport stream for session ID 0x104, with four bytes of extra data. @@ -1331,7 +1082,7 @@ TEST(HttpDecoderTestNoFixture, WebTransportStream) { TEST(HttpDecoderTestNoFixture, WebTransportStreamError) { HttpDecoder::Options options; options.allow_web_transport_stream = true; - testing::StrictMock visitor; + testing::StrictMock visitor; HttpDecoder decoder(&visitor, options); std::string input = absl::HexStringToBytes("404100"); @@ -1380,6 +1131,6 @@ TEST_F(HttpDecoderTest, DecodeSettings) { EXPECT_FALSE(HttpDecoder::DecodeSettings(input.data(), input.size(), &out)); } +} // namespace } // namespace test - } // namespace quic diff --git a/gquiche/quic/core/http/http_encoder.cc b/gquiche/quic/core/http/http_encoder.cc index 934d81a2..2c5d0eca 100644 --- a/gquiche/quic/core/http/http_encoder.cc +++ b/gquiche/quic/core/http/http_encoder.cc @@ -34,23 +34,30 @@ QuicByteCount GetTotalLength(QuicByteCount payload_length, HttpFrameType type) { } // namespace // static -QuicByteCount HttpEncoder::SerializeDataFrameHeader( +QuicByteCount HttpEncoder::GetDataFrameHeaderLength( + QuicByteCount payload_length) { + QUICHE_DCHECK_NE(0u, payload_length); + return QuicDataWriter::GetVarInt62Len(payload_length) + + QuicDataWriter::GetVarInt62Len( + static_cast(HttpFrameType::DATA)); +} + +// static +QuicBuffer HttpEncoder::SerializeDataFrameHeader( QuicByteCount payload_length, - std::unique_ptr* output) { + QuicBufferAllocator* allocator) { QUICHE_DCHECK_NE(0u, payload_length); - QuicByteCount header_length = QuicDataWriter::GetVarInt62Len(payload_length) + - QuicDataWriter::GetVarInt62Len( - static_cast(HttpFrameType::DATA)); + QuicByteCount header_length = GetDataFrameHeaderLength(payload_length); - output->reset(new char[header_length]); - QuicDataWriter writer(header_length, output->get()); + QuicBuffer header(allocator, header_length); + QuicDataWriter writer(header.size(), header.data()); if (WriteFrameHeader(payload_length, HttpFrameType::DATA, &writer)) { - return header_length; + return header; } QUIC_DLOG(ERROR) << "Http encoder failed when attempting to serialize data frame header."; - return 0; + return QuicBuffer(); } // static diff --git a/gquiche/quic/core/http/http_encoder.h b/gquiche/quic/core/http/http_encoder.h index 2c9bcdc8..1a11dc15 100644 --- a/gquiche/quic/core/http/http_encoder.h +++ b/gquiche/quic/core/http/http_encoder.h @@ -7,6 +7,7 @@ #include #include "gquiche/quic/core/http/http_frames.h" +#include "gquiche/quic/core/quic_buffer_allocator.h" #include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_export.h" @@ -21,11 +22,13 @@ class QUIC_EXPORT_PRIVATE HttpEncoder { public: HttpEncoder() = delete; - // Serializes a DATA frame header into a new buffer stored in |output|. - // Returns the length of the buffer on success, or 0 otherwise. - static QuicByteCount SerializeDataFrameHeader( - QuicByteCount payload_length, - std::unique_ptr* output); + // Returns the length of the header for a DATA frame. + static QuicByteCount GetDataFrameHeaderLength(QuicByteCount payload_length); + + // Serializes a DATA frame header into a QuicBuffer; returns said QuicBuffer + // on success, empty buffer otherwise. + static QuicBuffer SerializeDataFrameHeader(QuicByteCount payload_length, + QuicBufferAllocator* allocator); // Serializes a HEADERS frame header into a new buffer stored in |output|. // Returns the length of the buffer on success, or 0 otherwise. diff --git a/gquiche/quic/core/http/http_encoder_test.cc b/gquiche/quic/core/http/http_encoder_test.cc index bf9474ed..732c5630 100644 --- a/gquiche/quic/core/http/http_encoder_test.cc +++ b/gquiche/quic/core/http/http_encoder_test.cc @@ -5,6 +5,7 @@ #include "gquiche/quic/core/http/http_encoder.h" #include "absl/base/macros.h" +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" @@ -14,16 +15,15 @@ namespace quic { namespace test { TEST(HttpEncoderTest, SerializeDataFrameHeader) { - std::unique_ptr buffer; - uint64_t length = - HttpEncoder::SerializeDataFrameHeader(/* payload_length = */ 5, &buffer); + QuicBuffer buffer = HttpEncoder::SerializeDataFrameHeader( + /* payload_length = */ 5, SimpleBufferAllocator::Get()); char output[] = {// type (DATA) 0x00, // length 0x05}; - EXPECT_EQ(ABSL_ARRAYSIZE(output), length); - quiche::test::CompareCharArraysWithHexError("DATA", buffer.get(), length, - output, ABSL_ARRAYSIZE(output)); + EXPECT_EQ(ABSL_ARRAYSIZE(output), buffer.size()); + quiche::test::CompareCharArraysWithHexError( + "DATA", buffer.data(), buffer.size(), output, ABSL_ARRAYSIZE(output)); } TEST(HttpEncoderTest, SerializeHeadersFrameHeader) { @@ -87,40 +87,42 @@ TEST(HttpEncoderTest, SerializePriorityUpdateFrame) { PriorityUpdateFrame priority_update1; priority_update1.prioritized_element_type = REQUEST_STREAM; priority_update1.prioritized_element_id = 0x03; - char output1[] = {0x80, 0x0f, 0x07, 0x00, // type (PRIORITY_UPDATE) - 0x01, // length - 0x03}; // prioritized element id + uint8_t output1[] = {0x80, 0x0f, 0x07, 0x00, // type (PRIORITY_UPDATE) + 0x01, // length + 0x03}; // prioritized element id std::unique_ptr buffer; uint64_t length = HttpEncoder::SerializePriorityUpdateFrame(priority_update1, &buffer); EXPECT_EQ(ABSL_ARRAYSIZE(output1), length); - quiche::test::CompareCharArraysWithHexError("PRIORITY_UPDATE", buffer.get(), - length, output1, - ABSL_ARRAYSIZE(output1)); + quiche::test::CompareCharArraysWithHexError( + "PRIORITY_UPDATE", buffer.get(), length, reinterpret_cast(output1), + ABSL_ARRAYSIZE(output1)); } TEST(HttpEncoderTest, SerializeAcceptChFrame) { AcceptChFrame accept_ch; - char output1[] = {0x40, 0x89, // type (ACCEPT_CH) - 0x00}; // length + uint8_t output1[] = {0x40, 0x89, // type (ACCEPT_CH) + 0x00}; // length std::unique_ptr buffer; uint64_t length = HttpEncoder::SerializeAcceptChFrame(accept_ch, &buffer); EXPECT_EQ(ABSL_ARRAYSIZE(output1), length); quiche::test::CompareCharArraysWithHexError("ACCEPT_CH", buffer.get(), length, - output1, ABSL_ARRAYSIZE(output1)); + reinterpret_cast(output1), + ABSL_ARRAYSIZE(output1)); accept_ch.entries.push_back({"foo", "bar"}); - char output2[] = {0x40, 0x89, // type (ACCEPT_CH) - 0x08, // payload length - 0x03, 0x66, 0x6f, 0x6f, // length of "foo"; "foo" - 0x03, 0x62, 0x61, 0x72}; // length of "bar"; "bar" + uint8_t output2[] = {0x40, 0x89, // type (ACCEPT_CH) + 0x08, // payload length + 0x03, 0x66, 0x6f, 0x6f, // length of "foo"; "foo" + 0x03, 0x62, 0x61, 0x72}; // length of "bar"; "bar" length = HttpEncoder::SerializeAcceptChFrame(accept_ch, &buffer); EXPECT_EQ(ABSL_ARRAYSIZE(output2), length); quiche::test::CompareCharArraysWithHexError("ACCEPT_CH", buffer.get(), length, - output2, ABSL_ARRAYSIZE(output2)); + reinterpret_cast(output2), + ABSL_ARRAYSIZE(output2)); } TEST(HttpEncoderTest, SerializeWebTransportStreamFrameHeader) { diff --git a/gquiche/quic/core/http/http_frames.h b/gquiche/quic/core/http/http_frames.h index dcff2743..206911c0 100644 --- a/gquiche/quic/core/http/http_frames.h +++ b/gquiche/quic/core/http/http_frames.h @@ -20,6 +20,9 @@ namespace quic { +// TODO(b/171463363): Remove. +using PushId = uint64_t; + enum class HttpFrameType { DATA = 0x0, HEADERS = 0x1, @@ -28,11 +31,9 @@ enum class HttpFrameType { PUSH_PROMISE = 0x5, GOAWAY = 0x7, MAX_PUSH_ID = 0xD, - // https://tools.ietf.org/html/draft-ietf-httpbis-priority-01 - PRIORITY_UPDATE = 0XF, // https://tools.ietf.org/html/draft-davidben-http-client-hint-reliability-02 ACCEPT_CH = 0x89, - // https://tools.ietf.org/html/draft-ietf-httpbis-priority-02 + // https://tools.ietf.org/html/draft-ietf-httpbis-priority-03 PRIORITY_UPDATE_REQUEST_STREAM = 0xF0700, // https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-00.html WEBTRANSPORT_STREAM = 0x41, @@ -54,20 +55,6 @@ struct QUIC_EXPORT_PRIVATE HeadersFrame { absl::string_view headers; }; -// 7.2.3. CANCEL_PUSH -// -// The CANCEL_PUSH frame (type=0x3) is used to request cancellation of -// server push prior to the push stream being created. -using PushId = uint64_t; - -struct QUIC_EXPORT_PRIVATE CancelPushFrame { - PushId push_id; - - bool operator==(const CancelPushFrame& rhs) const { - return push_id == rhs.push_id; - } -}; - // 7.2.4. SETTINGS // // The SETTINGS frame (type=0x4) conveys configuration parameters that @@ -101,20 +88,6 @@ struct QUIC_EXPORT_PRIVATE SettingsFrame { } }; -// 7.2.5. PUSH_PROMISE -// -// The PUSH_PROMISE frame (type=0x05) is used to carry a request header -// set from server to client, as in HTTP/2. -// TODO(b/171463363): Remove. -struct QUIC_EXPORT_PRIVATE PushPromiseFrame { - PushId push_id; - absl::string_view headers; - - bool operator==(const PushPromiseFrame& rhs) const { - return push_id == rhs.push_id && headers == rhs.headers; - } -}; - // 7.2.6. GOAWAY // // The GOAWAY frame (type=0x7) is used to initiate shutdown of a connection by @@ -143,10 +116,9 @@ struct QUIC_EXPORT_PRIVATE MaxPushIdFrame { // https://httpwg.org/http-extensions/draft-ietf-httpbis-priority.html // // The PRIORITY_UPDATE frame specifies the sender-advised priority of a stream. -// https://tools.ietf.org/html/draft-ietf-httpbis-priority-01 uses frame type -// 0x0f, both for request streams and push streams. -// https://tools.ietf.org/html/draft-ietf-httpbis-priority-02 uses frame types -// 0xf0700 for request streams and 0xf0701 for push streams (not implemented). +// Frame type 0xf0700 (called PRIORITY_UPDATE_REQUEST_STREAM in the +// implementation) is used for for request streams. +// Frame type 0xf0701 is used for push streams and is not implemented. // Length of a priority frame's first byte. const QuicByteCount kPriorityFirstByteLength = 1; diff --git a/gquiche/quic/core/http/http_frames_test.cc b/gquiche/quic/core/http/http_frames_test.cc index 36d777c4..f790aa65 100644 --- a/gquiche/quic/core/http/http_frames_test.cc +++ b/gquiche/quic/core/http/http_frames_test.cc @@ -12,17 +12,6 @@ namespace quic { namespace test { -TEST(HttpFramesTest, CancelPushFrame) { - CancelPushFrame a{1}; - EXPECT_TRUE(a == a); - - CancelPushFrame b{1}; - EXPECT_TRUE(a == b); - - b.push_id = 2; - EXPECT_FALSE(a == b); -} - TEST(HttpFramesTest, SettingsFrame) { SettingsFrame a; EXPECT_TRUE(a == a); @@ -44,21 +33,6 @@ TEST(HttpFramesTest, SettingsFrame) { EXPECT_EQ("SETTINGS_QPACK_MAX_TABLE_CAPACITY = 1; ", s.str()); } -TEST(HttpFramesTest, PushPromiseFrame) { - PushPromiseFrame a{1, ""}; - EXPECT_TRUE(a == a); - - PushPromiseFrame b{2, ""}; - EXPECT_FALSE(a == b); - - b.push_id = 1; - EXPECT_TRUE(a == b); - - b.headers = "foo"; - EXPECT_FALSE(a == b); - EXPECT_TRUE(b == b); -} - TEST(HttpFramesTest, GoAwayFrame) { GoAwayFrame a{1}; EXPECT_TRUE(a == a); diff --git a/gquiche/quic/core/http/quic_client_promised_info.cc b/gquiche/quic/core/http/quic_client_promised_info.cc index c1713bfc..91c36e19 100644 --- a/gquiche/quic/core/http/quic_client_promised_info.cc +++ b/gquiche/quic/core/http/quic_client_promised_info.cc @@ -24,7 +24,11 @@ QuicClientPromisedInfo::QuicClientPromisedInfo( url_(std::move(url)), client_request_delegate_(nullptr) {} -QuicClientPromisedInfo::~QuicClientPromisedInfo() {} +QuicClientPromisedInfo::~QuicClientPromisedInfo() { + if (cleanup_alarm_ != nullptr) { + cleanup_alarm_->PermanentCancel(); + } +} void QuicClientPromisedInfo::CleanupAlarm::OnAlarm() { QUIC_DVLOG(1) << "self GC alarm for stream " << promised_->id_; diff --git a/gquiche/quic/core/http/quic_client_promised_info.h b/gquiche/quic/core/http/quic_client_promised_info.h index d7b4c91d..5487e0a6 100644 --- a/gquiche/quic/core/http/quic_client_promised_info.h +++ b/gquiche/quic/core/http/quic_client_promised_info.h @@ -84,7 +84,8 @@ class QUIC_EXPORT_PRIVATE QuicClientPromisedInfo private: friend class test::QuicClientPromisedInfoPeer; - class QUIC_EXPORT_PRIVATE CleanupAlarm : public QuicAlarm::Delegate { + class QUIC_EXPORT_PRIVATE CleanupAlarm + : public QuicAlarm::DelegateWithoutContext { public: explicit CleanupAlarm(QuicClientPromisedInfo* promised) : promised_(promised) {} diff --git a/gquiche/quic/core/http/quic_header_list.h b/gquiche/quic/core/http/quic_header_list.h index a7c939a9..966411c4 100644 --- a/gquiche/quic/core/http/quic_header_list.h +++ b/gquiche/quic/core/http/quic_header_list.h @@ -11,9 +11,9 @@ #include #include "absl/strings/string_view.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/common/quiche_circular_deque.h" #include "gquiche/spdy/core/spdy_header_block.h" #include "gquiche/spdy/core/spdy_headers_handler_interface.h" @@ -23,7 +23,8 @@ namespace quic { class QUIC_EXPORT_PRIVATE QuicHeaderList : public spdy::SpdyHeadersHandlerInterface { public: - using ListType = QuicCircularDeque>; + using ListType = + quiche::QuicheCircularDeque>; using value_type = ListType::value_type; using const_iterator = ListType::const_iterator; @@ -59,7 +60,7 @@ class QUIC_EXPORT_PRIVATE QuicHeaderList std::string DebugString() const; private: - QuicCircularDeque> header_list_; + quiche::QuicheCircularDeque> header_list_; // The limit on the size of the header list (defined by spec as name + value + // overhead for each header field). Headers over this limit will not be diff --git a/gquiche/quic/core/http/quic_headers_stream.h b/gquiche/quic/core/http/quic_headers_stream.h index 435593ca..f4c0282e 100644 --- a/gquiche/quic/core/http/quic_headers_stream.h +++ b/gquiche/quic/core/http/quic_headers_stream.h @@ -91,7 +91,7 @@ class QUIC_EXPORT_PRIVATE QuicHeadersStream : public QuicStream { QuicSpdySession* spdy_session_; // Headers that have not been fully acked. - QuicCircularDeque unacked_headers_; + quiche::QuicheCircularDeque unacked_headers_; }; } // namespace quic diff --git a/gquiche/quic/core/http/quic_headers_stream_test.cc b/gquiche/quic/core/http/quic_headers_stream_test.cc index 10080cd6..b2a8915d 100644 --- a/gquiche/quic/core/http/quic_headers_stream_test.cc +++ b/gquiche/quic/core/http/quic_headers_stream_test.cc @@ -563,6 +563,10 @@ TEST_P(QuicHeadersStreamTest, ProcessPriorityFrame) { } TEST_P(QuicHeadersStreamTest, ProcessPushPromiseDisabledSetting) { + if (perspective() != Perspective::IS_CLIENT) { + return; + } + session_.OnConfigNegotiated(); SpdySettingsIR data; // Respect supported settings frames SETTINGS_ENABLE_PUSH. @@ -570,15 +574,11 @@ TEST_P(QuicHeadersStreamTest, ProcessPushPromiseDisabledSetting) { SpdySerializedFrame frame(framer_->SerializeFrame(data)); stream_frame_.data_buffer = frame.data(); stream_frame_.data_length = frame.size(); - if (perspective() == Perspective::IS_CLIENT) { - EXPECT_CALL( - *connection_, - CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, - "Unsupported field of HTTP/2 SETTINGS frame: 2", _)); - } + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_HEADERS_STREAM_DATA, + "Unsupported field of HTTP/2 SETTINGS frame: 2", _)); headers_stream_->OnStreamFrame(stream_frame_); - EXPECT_EQ(session_.server_push_enabled(), - perspective() == Perspective::IS_CLIENT); } TEST_P(QuicHeadersStreamTest, ProcessLargeRawData) { diff --git a/gquiche/quic/core/http/quic_receive_control_stream.cc b/gquiche/quic/core/http/quic_receive_control_stream.cc index cb27301f..90483539 100644 --- a/gquiche/quic/core/http/quic_receive_control_stream.cc +++ b/gquiche/quic/core/http/quic_receive_control_stream.cc @@ -15,7 +15,7 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -24,7 +24,6 @@ QuicReceiveControlStream::QuicReceiveControlStream( QuicSpdySession* spdy_session) : QuicStream(pending, spdy_session, - READ_UNIDIRECTIONAL, /*is_static=*/true), settings_frame_received_(false), decoder_(this), @@ -65,15 +64,6 @@ void QuicReceiveControlStream::OnError(HttpDecoder* decoder) { stream_delegate()->OnStreamError(decoder->error(), decoder->error_detail()); } -bool QuicReceiveControlStream::OnCancelPushFrame(const CancelPushFrame& frame) { - if (spdy_session()->debug_visitor()) { - spdy_session()->debug_visitor()->OnCancelPushFrameReceived(frame); - } - - // TODO(b/151841240): Handle CANCEL_PUSH frames instead of ignoring them. - return ValidateFrameType(HttpFrameType::CANCEL_PUSH); -} - bool QuicReceiveControlStream::OnMaxPushIdFrame(const MaxPushIdFrame& frame) { if (spdy_session()->debug_visitor()) { spdy_session()->debug_visitor()->OnMaxPushIdFrameReceived(frame); @@ -145,33 +135,9 @@ bool QuicReceiveControlStream::OnHeadersFrameEnd() { return false; } -bool QuicReceiveControlStream::OnPushPromiseFrameStart( - QuicByteCount /*header_length*/) { - return ValidateFrameType(HttpFrameType::PUSH_PROMISE); -} - -bool QuicReceiveControlStream::OnPushPromiseFramePushId( - PushId /*push_id*/, - QuicByteCount /*push_id_length*/, - QuicByteCount /*header_block_length*/) { - QUICHE_NOTREACHED(); - return false; -} - -bool QuicReceiveControlStream::OnPushPromiseFramePayload( - absl::string_view /*payload*/) { - QUICHE_NOTREACHED(); - return false; -} - -bool QuicReceiveControlStream::OnPushPromiseFrameEnd() { - QUICHE_NOTREACHED(); - return false; -} - bool QuicReceiveControlStream::OnPriorityUpdateFrameStart( QuicByteCount /*header_length*/) { - return ValidateFrameType(HttpFrameType::PRIORITY_UPDATE); + return ValidateFrameType(HttpFrameType::PRIORITY_UPDATE_REQUEST_STREAM); } bool QuicReceiveControlStream::OnPriorityUpdateFrame( @@ -265,9 +231,8 @@ bool QuicReceiveControlStream::OnUnknownFrameEnd() { bool QuicReceiveControlStream::ValidateFrameType(HttpFrameType frame_type) { // Certain frame types are forbidden. - if ((frame_type == HttpFrameType::DATA || - frame_type == HttpFrameType::HEADERS || - frame_type == HttpFrameType::PUSH_PROMISE) || + if (frame_type == HttpFrameType::DATA || + frame_type == HttpFrameType::HEADERS || (spdy_session()->perspective() == Perspective::IS_CLIENT && frame_type == HttpFrameType::MAX_PUSH_ID) || (spdy_session()->perspective() == Perspective::IS_SERVER && diff --git a/gquiche/quic/core/http/quic_receive_control_stream.h b/gquiche/quic/core/http/quic_receive_control_stream.h index b24b306b..a3044d7c 100644 --- a/gquiche/quic/core/http/quic_receive_control_stream.h +++ b/gquiche/quic/core/http/quic_receive_control_stream.h @@ -33,9 +33,8 @@ class QUIC_EXPORT_PRIVATE QuicReceiveControlStream // Implementation of QuicStream. void OnDataAvailable() override; - // HttpDecoderVisitor implementation. + // HttpDecoder::Visitor implementation. void OnError(HttpDecoder* decoder) override; - bool OnCancelPushFrame(const CancelPushFrame& frame) override; bool OnMaxPushIdFrame(const MaxPushIdFrame& frame) override; bool OnGoAwayFrame(const GoAwayFrame& frame) override; bool OnSettingsFrameStart(QuicByteCount header_length) override; @@ -48,12 +47,6 @@ class QUIC_EXPORT_PRIVATE QuicReceiveControlStream QuicByteCount payload_length) override; bool OnHeadersFramePayload(absl::string_view payload) override; bool OnHeadersFrameEnd() override; - bool OnPushPromiseFrameStart(QuicByteCount header_length) override; - bool OnPushPromiseFramePushId(PushId push_id, - QuicByteCount push_id_length, - QuicByteCount header_block_length) override; - bool OnPushPromiseFramePayload(absl::string_view payload) override; - bool OnPushPromiseFrameEnd() override; bool OnPriorityUpdateFrameStart(QuicByteCount header_length) override; bool OnPriorityUpdateFrame(const PriorityUpdateFrame& frame) override; bool OnAcceptChFrameStart(QuicByteCount header_length) override; diff --git a/gquiche/quic/core/http/quic_receive_control_stream_test.cc b/gquiche/quic/core/http/quic_receive_control_stream_test.cc index 4c814a63..ff1c6836 100644 --- a/gquiche/quic/core/http/quic_receive_control_stream_test.cc +++ b/gquiche/quic/core/http/quic_receive_control_stream_test.cc @@ -9,13 +9,13 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/http/http_constants.h" #include "gquiche/quic/core/qpack/qpack_header_table.h" +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_encoder_peer.h" #include "gquiche/quic/test_tools/quic_spdy_session_peer.h" #include "gquiche/quic/test_tools/quic_stream_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { @@ -240,12 +240,11 @@ TEST_P(QuicReceiveControlStreamTest, ReceiveSettingsFragments) { TEST_P(QuicReceiveControlStreamTest, ReceiveWrongFrame) { // DATA frame header without payload. - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(/* payload_length = */ 2, &buffer); - std::string data = std::string(buffer.get(), header_length); + QuicBuffer data = HttpEncoder::SerializeDataFrameHeader( + /* payload_length = */ 2, SimpleBufferAllocator::Get()); - QuicStreamFrame frame(receive_control_stream_->id(), false, 1, data); + QuicStreamFrame frame(receive_control_stream_->id(), false, 1, + data.AsStringView()); EXPECT_CALL( *connection_, CloseConnection(QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM, _, _)); @@ -261,7 +260,7 @@ TEST_P(QuicReceiveControlStreamTest, EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_MISSING_SETTINGS_FRAME, "First frame received on control stream is type " - "15, but it must be SETTINGS.", + "984832, but it must be SETTINGS.", _)) .WillOnce( Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); @@ -309,9 +308,7 @@ TEST_P(QuicReceiveControlStreamTest, PushPromiseOnControlStreamShouldClose) { "00"); // push ID QuicStreamFrame frame(receive_control_stream_->id(), false, 1, push_promise_frame); - EXPECT_CALL( - *connection_, - CloseConnection(QUIC_HTTP_FRAME_UNEXPECTED_ON_CONTROL_STREAM, _, _)) + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, _, _)) .WillOnce( Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); @@ -382,11 +379,8 @@ TEST_P(QuicReceiveControlStreamTest, CancelPushFrameBeforeSettings) { "01" // payload length "01"); // push ID - EXPECT_CALL(*connection_, - CloseConnection(QUIC_HTTP_MISSING_SETTINGS_FRAME, - "First frame received on control stream is type " - "3, but it must be SETTINGS.", - _)) + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, + "CANCEL_PUSH frame received.", _)) .WillOnce( Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); diff --git a/gquiche/quic/core/http/quic_send_control_stream.cc b/gquiche/quic/core/http/quic_send_control_stream.cc index 6afeaa1d..ce587138 100644 --- a/gquiche/quic/core/http/quic_send_control_stream.cc +++ b/gquiche/quic/core/http/quic_send_control_stream.cc @@ -31,7 +31,7 @@ void QuicSendControlStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { << "OnStreamReset() called for write unidirectional stream."; } -bool QuicSendControlStream::OnStopSending(QuicRstStreamErrorCode /* code */) { +bool QuicSendControlStream::OnStopSending(QuicResetStreamError /* code */) { stream_delegate()->OnStreamError( QUIC_HTTP_CLOSED_CRITICAL_STREAM, "STOP_SENDING received for send control stream"); diff --git a/gquiche/quic/core/http/quic_send_control_stream.h b/gquiche/quic/core/http/quic_send_control_stream.h index 78556a79..7b458820 100644 --- a/gquiche/quic/core/http/quic_send_control_stream.h +++ b/gquiche/quic/core/http/quic_send_control_stream.h @@ -31,7 +31,7 @@ class QUIC_EXPORT_PRIVATE QuicSendControlStream : public QuicStream { // Overriding QuicStream::OnStopSending() to make sure control stream is never // closed before connection. void OnStreamReset(const QuicRstStreamFrame& frame) override; - bool OnStopSending(QuicRstStreamErrorCode code) override; + bool OnStopSending(QuicResetStreamError code) override; // Send SETTINGS frame if it hasn't been sent yet. Settings frame must be the // first frame sent on this stream. diff --git a/gquiche/quic/core/http/quic_send_control_stream_test.cc b/gquiche/quic/core/http/quic_send_control_stream_test.cc index 1b7412d7..672a2a21 100644 --- a/gquiche/quic/core/http/quic_send_control_stream_test.cc +++ b/gquiche/quic/core/http/quic_send_control_stream_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/test_tools/quic_config_peer.h" #include "gquiche/quic/test_tools/quic_spdy_session_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace quic { @@ -72,9 +71,7 @@ class QuicSendControlStreamTest : public QuicTestWithParam { public: QuicSendControlStreamTest() : connection_(new StrictMock( - &helper_, - &alarm_factory_, - perspective(), + &helper_, &alarm_factory_, perspective(), SupportedVersions(GetParam().version))), session_(connection_) { ON_CALL(session_, WritevData(_, _, _, _, _, _)) @@ -105,8 +102,7 @@ class QuicSendControlStreamTest : public QuicTestWithParam { QuicSendControlStream* send_control_stream_; }; -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSendControlStreamTest, +INSTANTIATE_TEST_SUITE_P(Tests, QuicSendControlStreamTest, ::testing::ValuesIn(GetTestParams()), ::testing::PrintToStringParamName()); @@ -134,25 +130,78 @@ TEST_P(QuicSendControlStreamTest, WriteSettings) { "4040" // 0x40 as the reserved frame type "01" // 1 byte frame length "61"); // payload "a" - if (GetQuicReloadableFlag(quic_h3_datagram)) { + if ((!GetQuicReloadableFlag(quic_verify_request_headers_2) || + perspective() == Perspective::IS_CLIENT) && + QuicSpdySessionPeer::LocalHttpDatagramSupport(&session_) == + HttpDatagramSupport::kDraft00And04) { + expected_write_data = absl::HexStringToBytes( + "00" // stream type: control stream + "04" // frame type: SETTINGS frame + "0e" // frame length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "40ff" // 255 + "06" // SETTINGS_MAX_HEADER_LIST_SIZE + "4400" // 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "10" // 16 + "4040" // 0x40 as the reserved settings id + "14" // 20 + "4276" // SETTINGS_H3_DATAGRAM_DRAFT00 + "01" // 1 + "800ffd277" // SETTINGS_H3_DATAGRAM_DRAFT04 + "01" // 1 + "4040" // 0x40 as the reserved frame type + "01" // 1 byte frame length + "61"); // payload "a" + } + if (GetQuicReloadableFlag(quic_verify_request_headers_2) && + perspective() == Perspective::IS_SERVER && + QuicSpdySessionPeer::LocalHttpDatagramSupport(&session_) == + HttpDatagramSupport::kNone) { expected_write_data = absl::HexStringToBytes( "00" // stream type: control stream "04" // frame type: SETTINGS frame - "0e" // frame length + "0d" // frame length "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY "40ff" // 255 "06" // SETTINGS_MAX_HEADER_LIST_SIZE "4400" // 1024 "07" // SETTINGS_QPACK_BLOCKED_STREAMS "10" // 16 + "08" // SETTINGS_ENABLE_CONNECT_PROTOCOL + "01" // 1 "4040" // 0x40 as the reserved settings id "14" // 20 - "4276" // SETTINGS_H3_DATAGRAM - "01" // 1 "4040" // 0x40 as the reserved frame type "01" // 1 byte frame length "61"); // payload "a" } + if (GetQuicReloadableFlag(quic_verify_request_headers_2) && + perspective() == Perspective::IS_SERVER && + QuicSpdySessionPeer::LocalHttpDatagramSupport(&session_) != + HttpDatagramSupport::kNone) { + expected_write_data = absl::HexStringToBytes( + "00" // stream type: control stream + "04" // frame type: SETTINGS frame + "11" // frame length + "01" // SETTINGS_QPACK_MAX_TABLE_CAPACITY + "40ff" // 255 + "06" // SETTINGS_MAX_HEADER_LIST_SIZE + "4400" // 1024 + "07" // SETTINGS_QPACK_BLOCKED_STREAMS + "10" // 16 + "08" // SETTINGS_ENABLE_CONNECT_PROTOCOL + "01" // 1 + "4040" // 0x40 as the reserved settings id + "14" // 20 + "4276" // SETTINGS_H3_DATAGRAM_DRAFT00 + "01" // 1 + "800ffd277" // SETTINGS_H3_DATAGRAM_DRAFT04 + "01" // 1 + "4040" // 0x40 as the reserved frame type + "01" // 1 byte frame length + "61"); // payload "a" + } auto buffer = std::make_unique(expected_write_data.size()); QuicDataWriter writer(expected_write_data.size(), buffer.get()); @@ -217,7 +266,8 @@ TEST_P(QuicSendControlStreamTest, CloseControlStream) { Initialize(); EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, _, _)); - send_control_stream_->OnStopSending(QUIC_STREAM_CANCELLED); + send_control_stream_->OnStopSending( + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED)); } TEST_P(QuicSendControlStreamTest, ReceiveDataOnSendControlStream) { diff --git a/gquiche/quic/core/http/quic_server_session_base.cc b/gquiche/quic/core/http/quic_server_session_base.cc index 70054692..cf19ed83 100644 --- a/gquiche/quic/core/http/quic_server_session_base.cc +++ b/gquiche/quic/core/http/quic_server_session_base.cc @@ -10,6 +10,7 @@ #include "gquiche/quic/core/quic_connection.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_tag.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" @@ -50,10 +51,28 @@ void QuicServerSessionBase::OnConfigNegotiated() { return; } - // Disable server push if peer sends the corresponding connection option. - if (!version().UsesHttp3() && - ContainsQuicTag(config()->ReceivedConnectionOptions(), kQNSP)) { - OnSetting(spdy::SETTINGS_ENABLE_PUSH, 0); + const CachedNetworkParameters* cached_network_params = + crypto_stream_->PreviousCachedNetworkParams(); + + // Set the initial rtt from cached_network_params.min_rtt_ms, which comes from + // a validated address token. This will override the initial rtt that may have + // been set by the transport parameters. + if (add_cached_network_parameters_to_address_token() && version().UsesTls() && + cached_network_params != nullptr) { + if (cached_network_params->serving_region() == serving_region_) { + QUIC_CODE_COUNT(quic_server_received_network_params_at_same_region); + if (ContainsQuicTag(config()->ReceivedConnectionOptions(), kTRTT)) { + QUIC_DLOG(INFO) + << "Server: Setting initial rtt to " + << cached_network_params->min_rtt_ms() + << "ms which is received from a validated address token"; + connection()->sent_packet_manager().SetInitialRtt( + QuicTime::Delta::FromMilliseconds( + cached_network_params->min_rtt_ms())); + } + } else { + QUIC_CODE_COUNT(quic_server_received_network_params_at_different_region); + } } // Enable bandwidth resumption if peer sent correct connection options. @@ -67,13 +86,14 @@ void QuicServerSessionBase::OnConfigNegotiated() { // If the client has provided a bandwidth estimate from the same serving // region as this server, then decide whether to use the data for bandwidth // resumption. - const CachedNetworkParameters* cached_network_params = - crypto_stream_->PreviousCachedNetworkParams(); if (cached_network_params != nullptr && cached_network_params->serving_region() == serving_region_) { - // Log the received connection parameters, regardless of how they - // get used for bandwidth resumption. - connection()->OnReceiveConnectionState(*cached_network_params); + if (!add_cached_network_parameters_to_address_token() || + !version().UsesTls()) { + // Log the received connection parameters, regardless of how they + // get used for bandwidth resumption. + connection()->OnReceiveConnectionState(*cached_network_params); + } if (bandwidth_resumption_enabled_) { // Only do bandwidth resumption if estimate is recent enough. @@ -156,48 +176,75 @@ void QuicServerSessionBase::OnCongestionWindowChange(QuicTime now) { return; } - bandwidth_estimate_sent_to_client_ = new_bandwidth_estimate; - QUIC_DVLOG(1) << "Server: sending new bandwidth estimate (KBytes/s): " - << bandwidth_estimate_sent_to_client_.ToKBytesPerSecond(); - - // Include max bandwidth in the update. - QuicBandwidth max_bandwidth_estimate = - bandwidth_recorder->MaxBandwidthEstimate(); - int32_t max_bandwidth_timestamp = bandwidth_recorder->MaxBandwidthTimestamp(); - - // Fill the proto before passing it to the crypto stream to send. - const int32_t bw_estimate_bytes_per_second = - BandwidthToCachedParameterBytesPerSecond( - bandwidth_estimate_sent_to_client_); - const int32_t max_bw_estimate_bytes_per_second = - BandwidthToCachedParameterBytesPerSecond(max_bandwidth_estimate); - QUIC_BUG_IF(quic_bug_12513_1, max_bw_estimate_bytes_per_second < 0) - << max_bw_estimate_bytes_per_second; - QUIC_BUG_IF(quic_bug_10393_1, bw_estimate_bytes_per_second < 0) - << bw_estimate_bytes_per_second; + if (add_cached_network_parameters_to_address_token()) { + if (version().UsesTls()) { + if (version().HasIetfQuicFrames() && MaybeSendAddressToken()) { + bandwidth_estimate_sent_to_client_ = new_bandwidth_estimate; + } + } else { + absl::optional cached_network_params = + GenerateCachedNetworkParameters(); - CachedNetworkParameters cached_network_params; - cached_network_params.set_bandwidth_estimate_bytes_per_second( - bw_estimate_bytes_per_second); - cached_network_params.set_max_bandwidth_estimate_bytes_per_second( - max_bw_estimate_bytes_per_second); - cached_network_params.set_max_bandwidth_timestamp_seconds( - max_bandwidth_timestamp); - cached_network_params.set_min_rtt_ms( - sent_packet_manager.GetRttStats()->min_rtt().ToMilliseconds()); - cached_network_params.set_previous_connection_state( - bandwidth_recorder->EstimateRecordedDuringSlowStart() - ? CachedNetworkParameters::SLOW_START - : CachedNetworkParameters::CONGESTION_AVOIDANCE); - cached_network_params.set_timestamp( - connection()->clock()->WallNow().ToUNIXSeconds()); - if (!serving_region_.empty()) { - cached_network_params.set_serving_region(serving_region_); - } + if (cached_network_params.has_value()) { + bandwidth_estimate_sent_to_client_ = new_bandwidth_estimate; + QUIC_DVLOG(1) << "Server: sending new bandwidth estimate (KBytes/s): " + << bandwidth_estimate_sent_to_client_.ToKBytesPerSecond(); + + QUICHE_DCHECK_EQ( + BandwidthToCachedParameterBytesPerSecond( + bandwidth_estimate_sent_to_client_), + cached_network_params->bandwidth_estimate_bytes_per_second()); + + crypto_stream_->SendServerConfigUpdate(&cached_network_params.value()); + + connection()->OnSendConnectionState(*cached_network_params); + } + } + } else { + bandwidth_estimate_sent_to_client_ = new_bandwidth_estimate; + QUIC_DVLOG(1) << "Server: sending new bandwidth estimate (KBytes/s): " + << bandwidth_estimate_sent_to_client_.ToKBytesPerSecond(); + + // Include max bandwidth in the update. + QuicBandwidth max_bandwidth_estimate = + bandwidth_recorder->MaxBandwidthEstimate(); + int32_t max_bandwidth_timestamp = + bandwidth_recorder->MaxBandwidthTimestamp(); + + // Fill the proto before passing it to the crypto stream to send. + const int32_t bw_estimate_bytes_per_second = + BandwidthToCachedParameterBytesPerSecond( + bandwidth_estimate_sent_to_client_); + const int32_t max_bw_estimate_bytes_per_second = + BandwidthToCachedParameterBytesPerSecond(max_bandwidth_estimate); + QUIC_BUG_IF(quic_bug_12513_1, max_bw_estimate_bytes_per_second < 0) + << max_bw_estimate_bytes_per_second; + QUIC_BUG_IF(quic_bug_10393_1, bw_estimate_bytes_per_second < 0) + << bw_estimate_bytes_per_second; + + CachedNetworkParameters cached_network_params; + cached_network_params.set_bandwidth_estimate_bytes_per_second( + bw_estimate_bytes_per_second); + cached_network_params.set_max_bandwidth_estimate_bytes_per_second( + max_bw_estimate_bytes_per_second); + cached_network_params.set_max_bandwidth_timestamp_seconds( + max_bandwidth_timestamp); + cached_network_params.set_min_rtt_ms( + sent_packet_manager.GetRttStats()->min_rtt().ToMilliseconds()); + cached_network_params.set_previous_connection_state( + bandwidth_recorder->EstimateRecordedDuringSlowStart() + ? CachedNetworkParameters::SLOW_START + : CachedNetworkParameters::CONGESTION_AVOIDANCE); + cached_network_params.set_timestamp( + connection()->clock()->WallNow().ToUNIXSeconds()); + if (!serving_region_.empty()) { + cached_network_params.set_serving_region(serving_region_); + } - crypto_stream_->SendServerConfigUpdate(&cached_network_params); + crypto_stream_->SendServerConfigUpdate(&cached_network_params); - connection()->OnSendConnectionState(cached_network_params); + connection()->OnSendConnectionState(cached_network_params); + } last_scup_time_ = now; last_scup_packet_number_ = @@ -268,7 +315,7 @@ const QuicCryptoServerStreamBase* QuicServerSessionBase::GetCryptoStream() } int32_t QuicServerSessionBase::BandwidthToCachedParameterBytesPerSecond( - const QuicBandwidth& bandwidth) { + const QuicBandwidth& bandwidth) const { return static_cast(std::min( bandwidth.ToBytesPerSecond(), std::numeric_limits::max())); } @@ -288,4 +335,75 @@ void QuicServerSessionBase::SendSettingsToCryptoStream() { std::move(serialized_settings)); } +QuicSSLConfig QuicServerSessionBase::GetSSLConfig() const { + QUICHE_DCHECK(crypto_config_ && crypto_config_->proof_source()); + + QuicSSLConfig ssl_config = QuicSpdySession::GetSSLConfig(); + + ssl_config.disable_ticket_support = + GetQuicFlag(FLAGS_quic_disable_server_tls_resumption); + + if (!crypto_config_ || !crypto_config_->proof_source()) { + return ssl_config; + } + + absl::InlinedVector signature_algorithms = + crypto_config_->proof_source()->SupportedTlsSignatureAlgorithms(); + if (!signature_algorithms.empty()) { + ssl_config.signing_algorithm_prefs = std::move(signature_algorithms); + } + + return ssl_config; +} + +absl::optional +QuicServerSessionBase::GenerateCachedNetworkParameters() const { + QUICHE_DCHECK(add_cached_network_parameters_to_address_token()); + const QuicSentPacketManager& sent_packet_manager = + connection()->sent_packet_manager(); + const QuicSustainedBandwidthRecorder* bandwidth_recorder = + sent_packet_manager.SustainedBandwidthRecorder(); + + CachedNetworkParameters cached_network_params; + cached_network_params.set_timestamp( + connection()->clock()->WallNow().ToUNIXSeconds()); + + if (!sent_packet_manager.GetRttStats()->min_rtt().IsZero()) { + cached_network_params.set_min_rtt_ms( + sent_packet_manager.GetRttStats()->min_rtt().ToMilliseconds()); + } + + // Populate bandwidth estimates if any. + if (bandwidth_recorder != nullptr && bandwidth_recorder->HasEstimate()) { + const int32_t bw_estimate_bytes_per_second = + BandwidthToCachedParameterBytesPerSecond( + bandwidth_recorder->BandwidthEstimate()); + const int32_t max_bw_estimate_bytes_per_second = + BandwidthToCachedParameterBytesPerSecond( + bandwidth_recorder->MaxBandwidthEstimate()); + QUIC_BUG_IF(quic_bug_12513_1, max_bw_estimate_bytes_per_second < 0) + << max_bw_estimate_bytes_per_second; + QUIC_BUG_IF(quic_bug_10393_1, bw_estimate_bytes_per_second < 0) + << bw_estimate_bytes_per_second; + + cached_network_params.set_bandwidth_estimate_bytes_per_second( + bw_estimate_bytes_per_second); + cached_network_params.set_max_bandwidth_estimate_bytes_per_second( + max_bw_estimate_bytes_per_second); + cached_network_params.set_max_bandwidth_timestamp_seconds( + bandwidth_recorder->MaxBandwidthTimestamp()); + + cached_network_params.set_previous_connection_state( + bandwidth_recorder->EstimateRecordedDuringSlowStart() + ? CachedNetworkParameters::SLOW_START + : CachedNetworkParameters::CONGESTION_AVOIDANCE); + } + + if (!serving_region_.empty()) { + cached_network_params.set_serving_region(serving_region_); + } + + return cached_network_params; +} + } // namespace quic diff --git a/gquiche/quic/core/http/quic_server_session_base.h b/gquiche/quic/core/http/quic_server_session_base.h index 84c96e5a..0cba8382 100644 --- a/gquiche/quic/core/http/quic_server_session_base.h +++ b/gquiche/quic/core/http/quic_server_session_base.h @@ -68,12 +68,17 @@ class QUIC_EXPORT_PRIVATE QuicServerSessionBase : public QuicSpdySession { serving_region_ = serving_region; } + QuicSSLConfig GetSSLConfig() const override; + protected: // QuicSession methods(override them with return type of QuicSpdyStream*): QuicCryptoServerStreamBase* GetMutableCryptoStream() override; const QuicCryptoServerStreamBase* GetCryptoStream() const override; + absl::optional GenerateCachedNetworkParameters() + const override; + // If an outgoing stream can be created, return true. // Return false when connection is closed or forward secure encryption hasn't // established yet or number of server initiated streams already reaches the @@ -136,7 +141,7 @@ class QUIC_EXPORT_PRIVATE QuicServerSessionBase : public QuicSpdySession { // stored in CachedNetworkParameters. TODO(jokulik): This function // should go away once we fix http://b//27897982 int32_t BandwidthToCachedParameterBytesPerSecond( - const QuicBandwidth& bandwidth); + const QuicBandwidth& bandwidth) const; }; } // namespace quic diff --git a/gquiche/quic/core/http/quic_server_session_base_test.cc b/gquiche/quic/core/http/quic_server_session_base_test.cc index 7b3d7306..c6914ab5 100644 --- a/gquiche/quic/core/http/quic_server_session_base_test.cc +++ b/gquiche/quic/core/http/quic_server_session_base_test.cc @@ -10,7 +10,6 @@ #include #include "absl/memory/memory.h" -#include "absl/strings/string_view.h" #include "gquiche/quic/core/crypto/null_encrypter.h" #include "gquiche/quic/core/crypto/quic_crypto_server_config.h" #include "gquiche/quic/core/crypto/quic_random.h" @@ -18,6 +17,7 @@ #include "gquiche/quic/core/quic_connection.h" #include "gquiche/quic/core/quic_crypto_server_stream.h" #include "gquiche/quic/core/quic_crypto_server_stream_base.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/tls_server_handshaker.h" #include "gquiche/quic/platform/api/quic_expect_bug.h" @@ -51,6 +51,12 @@ namespace quic { namespace test { namespace { +// Data to be sent on a request stream. In Google QUIC, this is interpreted as +// DATA payload (there is no framing on request streams). In IETF QUIC, this is +// interpreted as HEADERS frame (type 0x1) with payload length 122 ('z'). Since +// no payload is included, QPACK decoder will not be invoked. +const char* const kStreamData = "\1z"; + class TestServerSession : public QuicServerSessionBase { public: TestServerSession(const QuicConfig& config, @@ -88,8 +94,8 @@ class TestServerSession : public QuicServerSessionBase { } QuicSpdyStream* CreateIncomingStream(PendingStream* pending) override { - QuicSpdyStream* stream = new QuicSimpleServerStream( - pending, this, BIDIRECTIONAL, quic_simple_server_backend_); + QuicSpdyStream* stream = + new QuicSimpleServerStream(pending, this, quic_simple_server_backend_); ActivateStream(absl::WrapUnique(stream)); return stream; } @@ -148,7 +154,7 @@ class QuicServerSessionBaseTest : public QuicTestWithParam { config_.SetInitialSessionFlowControlWindowToSend( kInitialSessionFlowControlWindowForTest); - ParsedQuicVersionVector supported_versions = SupportedVersions(GetParam()); + ParsedQuicVersionVector supported_versions = SupportedVersions(version()); connection_ = new StrictMock( &helper_, &alarm_factory_, Perspective::IS_SERVER, supported_versions); connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); @@ -166,23 +172,24 @@ class QuicServerSessionBaseTest : public QuicTestWithParam { QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( session_->config(), kMinimumFlowControlSendWindow); session_->OnConfigNegotiated(); - if (connection_->version().SupportsAntiAmplificationLimit()) { + if (version().SupportsAntiAmplificationLimit()) { QuicConnectionPeer::SetAddressValidated(connection_); } } QuicStreamId GetNthClientInitiatedBidirectionalId(int n) { - return GetNthClientInitiatedBidirectionalStreamId( - connection_->transport_version(), n); + return GetNthClientInitiatedBidirectionalStreamId(transport_version(), n); } QuicStreamId GetNthServerInitiatedUnidirectionalId(int n) { return quic::test::GetNthServerInitiatedUnidirectionalStreamId( - connection_->transport_version(), n); + transport_version(), n); } + ParsedQuicVersion version() const { return GetParam(); } + QuicTransportVersion transport_version() const { - return connection_->transport_version(); + return version().transport_version; } // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a @@ -242,10 +249,9 @@ INSTANTIATE_TEST_SUITE_P(Tests, ::testing::PrintToStringParamName()); TEST_P(QuicServerSessionBaseTest, CloseStreamDueToReset) { - // Open a stream, then reset it. - // Send two bytes of payload to open it. + // Send some data open a stream, then reset it. QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("HT")); + kStreamData); session_->OnStreamFrame(data1); EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); @@ -298,9 +304,9 @@ TEST_P(QuicServerSessionBaseTest, NeverOpenStreamDueToReset) { InjectStopSendingFrame(GetNthClientInitiatedBidirectionalId(0)); EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); - // Send two bytes of payload. + QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("HT")); + kStreamData); session_->OnStreamFrame(data1); // The stream should never be opened, now that the reset is received. @@ -309,11 +315,11 @@ TEST_P(QuicServerSessionBaseTest, NeverOpenStreamDueToReset) { } TEST_P(QuicServerSessionBaseTest, AcceptClosedStream) { - // Send (empty) compressed headers followed by two bytes of data. + // Send some data to open two streams. QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("\1\0\0\0\0\0\0\0HT")); + kStreamData); QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(1), false, 0, - absl::string_view("\3\0\0\0\0\0\0\0HT")); + kStreamData); session_->OnStreamFrame(frame1); session_->OnStreamFrame(frame2); EXPECT_EQ(2u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); @@ -341,9 +347,9 @@ TEST_P(QuicServerSessionBaseTest, AcceptClosedStream) { // past the reset point of stream 3. As it's a closed stream we just drop the // data on the floor, but accept the packet because it has data for stream 5. QuicStreamFrame frame3(GetNthClientInitiatedBidirectionalId(0), false, 2, - absl::string_view("TP")); + kStreamData); QuicStreamFrame frame4(GetNthClientInitiatedBidirectionalId(1), false, 2, - absl::string_view("TP")); + kStreamData); session_->OnStreamFrame(frame3); session_->OnStreamFrame(frame4); // The stream should never be opened, now that the reset is received. @@ -373,7 +379,7 @@ TEST_P(QuicServerSessionBaseTest, MaxOpenStreams) { for (size_t i = 0; i < kMaxStreamsForTest; ++i) { EXPECT_TRUE(QuicServerSessionBasePeer::GetOrCreateStream(session_.get(), stream_id)); - stream_id += QuicUtils::StreamIdDelta(connection_->transport_version()); + stream_id += QuicUtils::StreamIdDelta(transport_version()); } if (!VersionHasIetfQuicFrames(transport_version())) { @@ -382,11 +388,11 @@ TEST_P(QuicServerSessionBaseTest, MaxOpenStreams) { for (size_t i = 0; i < kMaxStreamsMinimumIncrement; ++i) { EXPECT_TRUE(QuicServerSessionBasePeer::GetOrCreateStream(session_.get(), stream_id)); - stream_id += QuicUtils::StreamIdDelta(connection_->transport_version()); + stream_id += QuicUtils::StreamIdDelta(transport_version()); } } // Now violate the server's internal stream limit. - stream_id += QuicUtils::StreamIdDelta(connection_->transport_version()); + stream_id += QuicUtils::StreamIdDelta(transport_version()); if (!VersionHasIetfQuicFrames(transport_version())) { // For non-version 99, QUIC responds to an attempt to exceed the stream @@ -418,8 +424,7 @@ TEST_P(QuicServerSessionBaseTest, MaxAvailableBidirectionalStreams) { session_.get(), GetNthClientInitiatedBidirectionalId(0))); // Establish available streams up to the server's limit. - QuicStreamId next_id = - QuicUtils::StreamIdDelta(connection_->transport_version()); + QuicStreamId next_id = QuicUtils::StreamIdDelta(transport_version()); const int kLimitingStreamId = GetNthClientInitiatedBidirectionalId(kAvailableStreamLimit + 1); if (!VersionHasIetfQuicFrames(transport_version())) { @@ -456,7 +461,7 @@ TEST_P(QuicServerSessionBaseTest, GetEvenIncomingError) { TEST_P(QuicServerSessionBaseTest, GetStreamDisconnected) { // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. - if (GetParam() != AllSupportedVersions()[0]) { + if (version() != AllSupportedVersions()[0]) { return; } @@ -498,13 +503,19 @@ class MockTlsServerHandshaker : public TlsServerHandshaker { MockTlsServerHandshaker& operator=(const MockTlsServerHandshaker&) = delete; ~MockTlsServerHandshaker() override {} - MOCK_METHOD(void, - SendServerConfigUpdate, - (const CachedNetworkParameters*), + MOCK_METHOD(void, SendServerConfigUpdate, (const CachedNetworkParameters*), (override)); + + MOCK_METHOD(std::string, GetAddressToken, (const CachedNetworkParameters*), + (const, override)); }; TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) { + if (version().UsesTls() && !version().HasIetfQuicFrames()) { + // Skip the Txxx versions. + return; + } + // Test that bandwidth estimate updates are sent to the client, only when // bandwidth resumption is enabled, the bandwidth estimate has changed // sufficiently, enough time has passed, @@ -527,14 +538,13 @@ TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) { if (!VersionUsesHttp3(transport_version())) { session_->UnregisterStreamPriority( - QuicUtils::GetHeadersStreamId(connection_->transport_version()), + QuicUtils::GetHeadersStreamId(transport_version()), /*is_static=*/true); } QuicServerSessionBasePeer::SetCryptoStream(session_.get(), nullptr); MockQuicCryptoServerStream* quic_crypto_stream = nullptr; MockTlsServerHandshaker* tls_server_stream = nullptr; - if (session_->connection()->version().handshake_protocol == - PROTOCOL_QUIC_CRYPTO) { + if (version().handshake_protocol == PROTOCOL_QUIC_CRYPTO) { quic_crypto_stream = new MockQuicCryptoServerStream( &crypto_config_, &compressed_certs_cache_, session_.get(), &stream_helper_); @@ -548,7 +558,7 @@ TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) { } if (!VersionUsesHttp3(transport_version())) { session_->RegisterStreamPriority( - QuicUtils::GetHeadersStreamId(connection_->transport_version()), + QuicUtils::GetHeadersStreamId(transport_version()), /*is_static=*/true, spdy::SpdyStreamPrecedence(QuicStream::kDefaultPriority)); } @@ -571,11 +581,11 @@ TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) { // Queue up some pending data. if (!VersionUsesHttp3(transport_version())) { session_->MarkConnectionLevelWriteBlocked( - QuicUtils::GetHeadersStreamId(connection_->transport_version())); + QuicUtils::GetHeadersStreamId(transport_version())); } else { session_->MarkConnectionLevelWriteBlocked( - QuicUtils::GetFirstUnidirectionalStreamId( - connection_->transport_version(), Perspective::IS_SERVER)); + QuicUtils::GetFirstUnidirectionalStreamId(transport_version(), + Perspective::IS_SERVER)); } EXPECT_TRUE(session_->HasDataToWrite()); @@ -633,22 +643,30 @@ TEST_P(QuicServerSessionBaseTest, BandwidthEstimates) { EXPECT_CALL(*quic_crypto_stream, SendServerConfigUpdate(EqualsProto(expected_network_params))) .Times(1); - } else { + } else if (!GetQuicReloadableFlag( + quic_add_cached_network_parameters_to_address_token2)) { EXPECT_CALL(*tls_server_stream, SendServerConfigUpdate(EqualsProto(expected_network_params))) .Times(1); + } else { + EXPECT_CALL(*tls_server_stream, + GetAddressToken(EqualsProto(expected_network_params))) + .WillOnce(testing::Return("Test address token")); } EXPECT_CALL(*connection_, OnSendConnectionState(_)).Times(1); session_->OnCongestionWindowChange(now); } TEST_P(QuicServerSessionBaseTest, BandwidthResumptionExperiment) { - if (GetParam().handshake_protocol == PROTOCOL_TLS1_3) { - // This test relies on resumption, which is not currently supported by the - // TLS handshake. - // TODO(nharper): Add support for resumption to the TLS handshake. - return; + if (version().UsesTls()) { + if (!version().HasIetfQuicFrames()) { + // Skip the Txxx versions. + return; + } + // Avoid a QUIC_BUG in QuicSession::OnConfigNegotiated. + connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); } + // Test that if a client provides a CachedNetworkParameters with the same // serving region as the current server, and which was made within an hour of // now, that this data is passed down to the send algorithm. @@ -720,20 +738,6 @@ TEST_P(QuicServerSessionBaseTest, NoBandwidthResumptionByDefault) { QuicServerSessionBasePeer::IsBandwidthResumptionEnabled(session_.get())); } -TEST_P(QuicServerSessionBaseTest, TurnOffServerPush) { - if (session_->version().UsesHttp3()) { - return; - } - - EXPECT_TRUE(session_->server_push_enabled()); - QuicTagVector copt; - copt.push_back(kQNSP); - QuicConfigPeer::SetReceivedConnectionOptions(session_->config(), copt); - connection_->SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - session_->OnConfigNegotiated(); - EXPECT_FALSE(session_->server_push_enabled()); -} - // Tests which check the lifetime management of data members of // QuicCryptoServerStream objects when async GetProof is in use. class StreamMemberLifetimeTest : public QuicServerSessionBaseTest { @@ -762,7 +766,7 @@ INSTANTIATE_TEST_SUITE_P(StreamMemberLifetimeTests, // ProofSource::GetProof. Delay the completion of the operation until after the // stream has been destroyed, and verify that there are no memory bugs. TEST_P(StreamMemberLifetimeTest, Basic) { - if (GetParam().handshake_protocol == PROTOCOL_TLS1_3) { + if (version().handshake_protocol == PROTOCOL_TLS1_3) { // This test depends on the QUIC crypto protocol, so it is disabled for the // TLS handshake. // TODO(nharper): Fix this test so it doesn't rely on QUIC crypto. @@ -771,9 +775,9 @@ TEST_P(StreamMemberLifetimeTest, Basic) { const QuicClock* clock = helper_.GetClock(); CryptoHandshakeMessage chlo = crypto_test_utils::GenerateDefaultInchoateCHLO( - clock, GetParam().transport_version, &crypto_config_); + clock, transport_version(), &crypto_config_); chlo.SetVector(kCOPT, QuicTagVector{kREJ}); - std::vector packet_version_list = {GetParam()}; + std::vector packet_version_list = {version()}; std::unique_ptr packet(ConstructEncryptedPacket( TestConnectionId(1), EmptyQuicConnectionId(), true, false, 1, std::string(chlo.GetSerialized().AsStringPiece()), CONNECTION_ID_PRESENT, diff --git a/gquiche/quic/core/http/quic_spdy_client_session.cc b/gquiche/quic/core/http/quic_spdy_client_session.cc index 80916d09..8937db2c 100644 --- a/gquiche/quic/core/http/quic_spdy_client_session.cc +++ b/gquiche/quic/core/http/quic_spdy_client_session.cc @@ -126,6 +126,10 @@ int QuicSpdyClientSession::GetNumSentClientHellos() const { return crypto_stream_->num_sent_client_hellos(); } +bool QuicSpdyClientSession::IsResumption() const { + return crypto_stream_->IsResumption(); +} + bool QuicSpdyClientSession::EarlyDataAccepted() const { return crypto_stream_->EarlyDataAccepted(); } @@ -181,8 +185,7 @@ bool QuicSpdyClientSession::ShouldCreateIncomingStream(QuicStreamId id) { QuicSpdyStream* QuicSpdyClientSession::CreateIncomingStream( PendingStream* pending) { - QuicSpdyStream* stream = - new QuicSpdyClientStream(pending, this, READ_UNIDIRECTIONAL); + QuicSpdyStream* stream = new QuicSpdyClientStream(pending, this); ActivateStream(absl::WrapUnique(stream)); return stream; } diff --git a/gquiche/quic/core/http/quic_spdy_client_session.h b/gquiche/quic/core/http/quic_spdy_client_session.h index f1fbf735..99bb2ae6 100644 --- a/gquiche/quic/core/http/quic_spdy_client_session.h +++ b/gquiche/quic/core/http/quic_spdy_client_session.h @@ -61,6 +61,10 @@ class QUIC_EXPORT_PRIVATE QuicSpdyClientSession // than the number of round-trips needed for the handshake. int GetNumSentClientHellos() const; + // Return true if the handshake performed is a TLS resumption. + // Always return false for QUIC Crypto. + bool IsResumption() const; + // Returns true if early data (0-RTT data) was sent and the server accepted // it. bool EarlyDataAccepted() const; @@ -103,7 +107,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdyClientSession // CreateOutgoingBidirectionalStream. virtual std::unique_ptr CreateClientStream(); - const QuicServerId& server_id() { return server_id_; } + const QuicServerId& server_id() const { return server_id_; } QuicCryptoClientConfig* crypto_config() { return crypto_config_; } private: diff --git a/gquiche/quic/core/http/quic_spdy_client_session_test.cc b/gquiche/quic/core/http/quic_spdy_client_session_test.cc index 885238b3..19618d86 100644 --- a/gquiche/quic/core/http/quic_spdy_client_session_test.cc +++ b/gquiche/quic/core/http/quic_spdy_client_session_test.cc @@ -278,10 +278,11 @@ TEST_P(QuicSpdyClientSessionTest, NoEncryptionAfterInitialEncryption) { EXPECT_TRUE(session_->CreateOutgoingBidirectionalStream() == nullptr); // Verify that no data may be send on existing streams. char data[] = "hello world"; - EXPECT_QUIC_BUG( + QuicConsumedData consumed = session_->WritevData(stream->id(), ABSL_ARRAYSIZE(data), 0, NO_FIN, - NOT_RETRANSMISSION, ENCRYPTION_INITIAL), - "Client: Try to send data of stream"); + NOT_RETRANSMISSION, ENCRYPTION_INITIAL); + EXPECT_EQ(0u, consumed.bytes_consumed); + EXPECT_FALSE(consumed.fin_consumed); } TEST_P(QuicSpdyClientSessionTest, MaxNumStreamsWithNoFinOrRst) { @@ -981,10 +982,10 @@ TEST_P(QuicSpdyClientSessionTest, OnSettingsFrame) { ApplicationState expected(std::begin(application_state), std::end(application_state)); session_->OnSettingsFrame(settings); - EXPECT_EQ(expected, - *client_session_cache_ - ->Lookup(QuicServerId(kServerHostname, kPort, false), nullptr) - ->application_state); + EXPECT_EQ(expected, *client_session_cache_ + ->Lookup(QuicServerId(kServerHostname, kPort, false), + session_->GetClock()->WallNow(), nullptr) + ->application_state); } TEST_P(QuicSpdyClientSessionTest, IetfZeroRttSetup) { diff --git a/gquiche/quic/core/http/quic_spdy_client_stream.cc b/gquiche/quic/core/http/quic_spdy_client_stream.cc index a9d1c45c..c5961c23 100644 --- a/gquiche/quic/core/http/quic_spdy_client_stream.cc +++ b/gquiche/quic/core/http/quic_spdy_client_stream.cc @@ -13,6 +13,7 @@ #include "gquiche/quic/core/http/web_transport_http3.h" #include "gquiche/quic/core/quic_alarm.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/common/quiche_text_utils.h" #include "gquiche/spdy/core/spdy_protocol.h" using spdy::SpdyHeaderBlock; @@ -31,9 +32,8 @@ QuicSpdyClientStream::QuicSpdyClientStream(QuicStreamId id, has_preliminary_headers_(false) {} QuicSpdyClientStream::QuicSpdyClientStream(PendingStream* pending, - QuicSpdyClientSession* session, - StreamType type) - : QuicSpdyStream(pending, session, type), + QuicSpdyClientSession* session) + : QuicSpdyStream(pending, session), content_length_(-1), response_code_(0), header_bytes_read_(0), @@ -51,6 +51,12 @@ void QuicSpdyClientStream::OnInitialHeadersComplete( QUICHE_DCHECK(headers_decompressed()); header_bytes_read_ += frame_len; + if (rst_sent()) { + // QuicSpdyStream::OnInitialHeadersComplete already rejected invalid + // response header. + return; + } + if (!SpdyUtils::CopyAndValidateHeaders(header_list, &content_length_, &response_headers_)) { QUIC_DLOG(ERROR) << "Failed to parse header list: " @@ -62,8 +68,13 @@ void QuicSpdyClientStream::OnInitialHeadersComplete( if (web_transport() != nullptr) { web_transport()->HeadersReceived(response_headers_); if (!web_transport()->ready()) { - // Rejected due to status not being 200, or other reason. - WriteOrBufferData("", /*fin=*/true, nullptr); + // The request was rejected by WebTransport, typically due to not having a + // 2xx status. The reason we're using Reset() here rather than closing + // cleanly is that even if the server attempts to send us any form of body + // with a 4xx request, we've already set up the capsule parser, and we + // don't have any way to process anything from the response body in + // question. + Reset(QUIC_STREAM_CANCELLED); return; } } @@ -184,4 +195,25 @@ size_t QuicSpdyClientStream::SendRequest(SpdyHeaderBlock headers, return bytes_sent; } +bool QuicSpdyClientStream::AreHeadersValid( + const QuicHeaderList& header_list) const { + if (!GetQuicReloadableFlag(quic_verify_request_headers_2)) { + return true; + } + if (!QuicSpdyStream::AreHeadersValid(header_list)) { + return false; + } + // Verify the presence of :status header. + bool saw_status = false; + for (const std::pair& pair : header_list) { + if (pair.first == ":status") { + saw_status = true; + } else if (absl::StrContains(pair.first, ":")) { + QUIC_DLOG(ERROR) << "Unexpected ':' in header " << pair.first << "."; + return false; + } + } + return saw_status; +} + } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_client_stream.h b/gquiche/quic/core/http/quic_spdy_client_stream.h index f4dd7e68..d9a10357 100644 --- a/gquiche/quic/core/http/quic_spdy_client_stream.h +++ b/gquiche/quic/core/http/quic_spdy_client_stream.h @@ -25,8 +25,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdyClientStream : public QuicSpdyStream { QuicSpdyClientSession* session, StreamType type); QuicSpdyClientStream(PendingStream* pending, - QuicSpdyClientSession* spdy_session, - StreamType type); + QuicSpdyClientSession* spdy_session); QuicSpdyClientStream(const QuicSpdyClientStream&) = delete; QuicSpdyClientStream& operator=(const QuicSpdyClientStream&) = delete; ~QuicSpdyClientStream() override; @@ -75,6 +74,9 @@ class QUIC_EXPORT_PRIVATE QuicSpdyClientStream : public QuicSpdyStream { // of client-side streams should be able to set the priority. using QuicSpdyStream::SetPriority; + protected: + bool AreHeadersValid(const QuicHeaderList& header_list) const override; + private: // The parsed headers received from the server. spdy::SpdyHeaderBlock response_headers_; diff --git a/gquiche/quic/core/http/quic_spdy_client_stream_test.cc b/gquiche/quic/core/http/quic_spdy_client_stream_test.cc index 9e6f9d70..f5239e19 100644 --- a/gquiche/quic/core/http/quic_spdy_client_stream_test.cc +++ b/gquiche/quic/core/http/quic_spdy_client_stream_test.cc @@ -12,6 +12,7 @@ #include "gquiche/quic/core/crypto/null_encrypter.h" #include "gquiche/quic/core/http/quic_spdy_client_session.h" #include "gquiche/quic/core/http/spdy_utils.h" +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_socket_address.h" @@ -127,16 +128,40 @@ TEST_P(QuicSpdyClientStreamTest, TestReceivingIllegalResponseStatusCode) { IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); } +TEST_P(QuicSpdyClientStreamTest, InvalidResponseHeader) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + auto headers = AsHeaderList(std::vector>{ + {":status", "200"}, {":path", "/foo"}}); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_THAT(stream_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + +TEST_P(QuicSpdyClientStreamTest, MissingStatusCode) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + auto headers = AsHeaderList( + std::vector>{{"key", "value"}}); + EXPECT_CALL(*connection_, + OnStreamReset(stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD)); + stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), + headers); + EXPECT_THAT(stream_->stream_error(), + IsStreamError(QUIC_BAD_APPLICATION_PAYLOAD)); +} + TEST_P(QuicSpdyClientStreamTest, TestFraming) { auto headers = AsHeaderList(headers_); stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), headers); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); std::string data = VersionUsesHttp3(connection_->transport_version()) - ? header + body_ + ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); @@ -160,12 +185,11 @@ TEST_P(QuicSpdyClientStreamTest, Test100ContinueBeforeSuccessful) { headers = AsHeaderList(headers_); stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), headers); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = - connection_->version().UsesHttp3() ? header + body_ : body_; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); // Make sure the 200 response got parsed correctly. @@ -190,12 +214,11 @@ TEST_P(QuicSpdyClientStreamTest, TestUnknownInformationalBeforeSuccessful) { headers = AsHeaderList(headers_); stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), headers); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = - connection_->version().UsesHttp3() ? header + body_ : body_; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); + std::string data = VersionUsesHttp3(connection_->transport_version()) + ? absl::StrCat(header.AsStringView(), body_) + : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); // Make sure the 200 response got parsed correctly. @@ -222,12 +245,10 @@ TEST_P(QuicSpdyClientStreamTest, TestFramingOnePacket) { auto headers = AsHeaderList(headers_); stream_->OnStreamHeaderList(false, headers.uncompressed_header_bytes(), headers); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); std::string data = VersionUsesHttp3(connection_->transport_version()) - ? header + body_ + ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); @@ -247,12 +268,10 @@ TEST_P(QuicSpdyClientStreamTest, EXPECT_THAT(stream_->stream_error(), IsQuicStreamNoError()); EXPECT_EQ("200", stream_->response_headers().find(":status")->second); EXPECT_EQ(200, stream_->response_code()); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(large_body.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + large_body.length(), SimpleBufferAllocator::Get()); std::string data = VersionUsesHttp3(connection_->transport_version()) - ? header + large_body + ? absl::StrCat(header.AsStringView(), large_body) : large_body; EXPECT_CALL(session_, WriteControlFrame(_, _)); EXPECT_CALL(*connection_, @@ -290,12 +309,10 @@ TEST_P(QuicSpdyClientStreamTest, ReceivingTrailers) { // Now send the body, which should close the stream as the FIN has been // received, as well as all data. - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); std::string data = VersionUsesHttp3(connection_->transport_version()) - ? header + body_ + ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); diff --git a/gquiche/quic/core/http/quic_spdy_server_stream_base.cc b/gquiche/quic/core/http/quic_spdy_server_stream_base.cc index ea3029ba..dc31390b 100644 --- a/gquiche/quic/core/http/quic_spdy_server_stream_base.cc +++ b/gquiche/quic/core/http/quic_spdy_server_stream_base.cc @@ -4,9 +4,13 @@ #include "gquiche/quic/core/http/quic_spdy_server_stream_base.h" +#include "absl/strings/string_view.h" +#include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/quic_error_codes.h" -#include "gquiche/quic/core/quic_session.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" +#include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -16,9 +20,8 @@ QuicSpdyServerStreamBase::QuicSpdyServerStreamBase(QuicStreamId id, : QuicSpdyStream(id, session, type) {} QuicSpdyServerStreamBase::QuicSpdyServerStreamBase(PendingStream* pending, - QuicSpdySession* session, - StreamType type) - : QuicSpdyStream(pending, session, type) {} + QuicSpdySession* session) + : QuicSpdyStream(pending, session) {} void QuicSpdyServerStreamBase::CloseWriteSide() { if (!fin_received() && !rst_received() && sequencer()->ignore_read_data() && @@ -45,4 +48,87 @@ void QuicSpdyServerStreamBase::StopReading() { QuicSpdyStream::StopReading(); } +bool QuicSpdyServerStreamBase::AreHeadersValid( + const QuicHeaderList& header_list) const { + if (!GetQuicReloadableFlag(quic_verify_request_headers_2)) { + return true; + } + QUIC_RELOADABLE_FLAG_COUNT_N(quic_verify_request_headers_2, 2, 3); + if (!QuicSpdyStream::AreHeadersValid(header_list)) { + return false; + } + + bool saw_connect = false; + bool saw_protocol = false; + bool saw_path = false; + bool saw_scheme = false; + bool saw_method = false; + bool saw_authority = false; + bool is_extended_connect = false; + // Check if it is missing any required headers and if there is any disallowed + // ones. + for (const std::pair& pair : header_list) { + if (pair.first == ":method") { + saw_method = true; + if (pair.second == "CONNECT") { + saw_connect = true; + if (saw_protocol) { + is_extended_connect = true; + } + } + } else if (pair.first == ":protocol") { + saw_protocol = true; + if (saw_connect) { + is_extended_connect = true; + } + } else if (pair.first == ":scheme") { + saw_scheme = true; + } else if (pair.first == ":path") { + saw_path = true; + } else if (pair.first == ":authority") { + saw_authority = true; + } else if (absl::StrContains(pair.first, ":")) { + QUIC_DLOG(ERROR) << "Unexpected ':' in header " << pair.first << "."; + return false; + } + if (is_extended_connect) { + if (!spdy_session()->allow_extended_connect()) { + QUIC_DLOG(ERROR) + << "Received extended-CONNECT request while it is disabled."; + return false; + } + } else if (saw_method && !saw_connect) { + if (saw_protocol) { + QUIC_DLOG(ERROR) << "Receive non-CONNECT request with :protocol."; + return false; + } + } + } + + if (is_extended_connect) { + if (saw_scheme && saw_path && saw_authority) { + // Saw all the required pseudo headers. + return true; + } + QUIC_DLOG(ERROR) << "Missing required pseudo headers for extended-CONNECT."; + return false; + } + // This is a vanilla CONNECT or non-CONNECT request. + if (saw_connect) { + // Check vanilla CONNECT. + if (saw_path || saw_scheme) { + QUIC_DLOG(ERROR) + << "Received invalid CONNECT request with disallowed pseudo header."; + return false; + } + return true; + } + // Check non-CONNECT request. + if (saw_method && saw_authority && saw_path && saw_scheme) { + return true; + } + QUIC_LOG(ERROR) << "Missing required pseudo headers."; + return false; +} + } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_server_stream_base.h b/gquiche/quic/core/http/quic_spdy_server_stream_base.h index ed742303..9bb4d1af 100644 --- a/gquiche/quic/core/http/quic_spdy_server_stream_base.h +++ b/gquiche/quic/core/http/quic_spdy_server_stream_base.h @@ -14,9 +14,7 @@ class QUIC_NO_EXPORT QuicSpdyServerStreamBase : public QuicSpdyStream { QuicSpdyServerStreamBase(QuicStreamId id, QuicSpdySession* session, StreamType type); - QuicSpdyServerStreamBase(PendingStream* pending, - QuicSpdySession* session, - StreamType type); + QuicSpdyServerStreamBase(PendingStream* pending, QuicSpdySession* session); QuicSpdyServerStreamBase(const QuicSpdyServerStreamBase&) = delete; QuicSpdyServerStreamBase& operator=(const QuicSpdyServerStreamBase&) = delete; @@ -24,6 +22,9 @@ class QUIC_NO_EXPORT QuicSpdyServerStreamBase : public QuicSpdyStream { // when the stream has not received all the data. void CloseWriteSide() override; void StopReading() override; + + protected: + bool AreHeadersValid(const QuicHeaderList& header_list) const override; }; } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_server_stream_base_test.cc b/gquiche/quic/core/http/quic_spdy_server_stream_base_test.cc index 3cd6fd68..2a21e51e 100644 --- a/gquiche/quic/core/http/quic_spdy_server_stream_base_test.cc +++ b/gquiche/quic/core/http/quic_spdy_server_stream_base_test.cc @@ -6,7 +6,9 @@ #include "absl/memory/memory.h" #include "gquiche/quic/core/crypto/null_encrypter.h" +#include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/quic/test_tools/qpack/qpack_encoder_test_utils.h" #include "gquiche/quic/test_tools/quic_spdy_session_peer.h" #include "gquiche/quic/test_tools/quic_stream_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" @@ -19,8 +21,7 @@ namespace { class TestQuicSpdyServerStream : public QuicSpdyServerStreamBase { public: - TestQuicSpdyServerStream(QuicStreamId id, - QuicSpdySession* session, + TestQuicSpdyServerStream(QuicStreamId id, QuicSpdySession* session, StreamType type) : QuicSpdyServerStreamBase(id, session, type) {} @@ -30,8 +31,7 @@ class TestQuicSpdyServerStream : public QuicSpdyServerStreamBase { class QuicSpdyServerStreamBaseTest : public QuicTest { protected: QuicSpdyServerStreamBaseTest() - : session_(new MockQuicConnection(&helper_, - &alarm_factory_, + : session_(new MockQuicConnection(&helper_, &alarm_factory_, Perspective::IS_SERVER)) { session_.Initialize(); session_.connection()->SetEncrypter( @@ -56,10 +56,15 @@ TEST_F(QuicSpdyServerStreamBaseTest, stream_->StopReading(); if (session_.version().UsesHttp3()) { - EXPECT_CALL(session_, MaybeSendStopSendingFrame(_, QUIC_STREAM_NO_ERROR)) + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))) .Times(1); } else { - EXPECT_CALL(session_, MaybeSendRstStreamFrame(_, QUIC_STREAM_NO_ERROR, _)) + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_NO_ERROR), _)) .Times(1); } QuicStreamPeer::SetFinSent(stream_); @@ -73,9 +78,10 @@ TEST_F(QuicSpdyServerStreamBaseTest, EXPECT_CALL(session_, MaybeSendRstStreamFrame( _, - VersionHasIetfQuicFrames(session_.transport_version()) - ? QUIC_STREAM_CANCELLED - : QUIC_RST_ACKNOWLEDGEMENT, + QuicResetStreamError::FromInternal( + VersionHasIetfQuicFrames(session_.transport_version()) + ? QUIC_STREAM_CANCELLED + : QUIC_RST_ACKNOWLEDGEMENT), _)) .Times(1); QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), @@ -93,6 +99,241 @@ TEST_F(QuicSpdyServerStreamBaseTest, EXPECT_TRUE(stream_->write_side_closed()); } +TEST_F(QuicSpdyServerStreamBaseTest, AllowExtendedConnect) { + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_EQ(GetQuicReloadableFlag(quic_verify_request_headers_2) && + GetQuicReloadableFlag(quic_act_upon_invalid_header) && + !session_.allow_extended_connect(), + stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, AllowExtendedConnectProtocolFirst) { + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_EQ(GetQuicReloadableFlag(quic_verify_request_headers_2) && + GetQuicReloadableFlag(quic_act_upon_invalid_header) && + !session_.allow_extended_connect(), + stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidExtendedConnect) { + if (!session_.version().UsesHttp3()) { + return; + } + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, VanillaConnectAllowed) { + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeaderBlockEnd(128, 128); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_FALSE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidVanillaConnect) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "CONNECT"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidNonConnectWithProtocol) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "GET"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeader(":protocol", "webtransport"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutScheme) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :scheme should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":method", "GET"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutAuthority) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :authority should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":method", "GET"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutMethod) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :method should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":path", "/path"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestWithoutPath) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :path should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":method", "POST"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, InvalidRequestHeader) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + // A request without :path should be rejected. + QuicHeaderList header_list; + header_list.OnHeaderBlockStart(); + header_list.OnHeader(":authority", "www.google.com:4433"); + header_list.OnHeader(":scheme", "http"); + header_list.OnHeader(":method", "POST"); + header_list.OnHeader("invalid:header", "value"); + header_list.OnHeaderBlockEnd(128, 128); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamHeaderList(/*fin=*/false, 0, header_list); + EXPECT_TRUE(stream_->rst_sent()); +} + +TEST_F(QuicSpdyServerStreamBaseTest, EmptyHeaders) { + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + spdy::SpdyHeaderBlock empty_header; + quic::test::NoopQpackStreamSenderDelegate encoder_stream_sender_delegate; + quic::test::NoopDecoderStreamErrorDelegate decoder_stream_error_delegate; + auto qpack_encoder = + std::make_unique(&decoder_stream_error_delegate); + qpack_encoder->set_qpack_stream_sender_delegate( + &encoder_stream_sender_delegate); + std::string payload = + qpack_encoder->EncodeHeaderList(stream_->id(), empty_header, nullptr); + std::unique_ptr headers_buffer; + quic::QuicByteCount headers_frame_header_length = + quic::HttpEncoder::SerializeHeadersFrameHeader(payload.length(), + &headers_buffer); + absl::string_view headers_frame_header(headers_buffer.get(), + headers_frame_header_length); + + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), + _)); + stream_->OnStreamFrame(QuicStreamFrame( + stream_->id(), true, 0, absl::StrCat(headers_frame_header, payload))); + EXPECT_TRUE(stream_->rst_sent()); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_session.cc b/gquiche/quic/core/http/quic_spdy_session.cc index 97a2b7f2..107ff132 100644 --- a/gquiche/quic/core/http/quic_spdy_session.cc +++ b/gquiche/quic/core/http/quic_spdy_session.cc @@ -30,12 +30,9 @@ #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_stack_trace.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/spdy/core/http2_frame_decoder_adapter.h" using http2::Http2DecoderAdapter; -using spdy::HpackEntry; -using spdy::HpackHeaderTable; using spdy::Http2WeightToSpdy3Priority; using spdy::Spdy3PriorityToHttp2Weight; using spdy::SpdyErrorCode; @@ -73,10 +70,6 @@ class AlpsFrameDecoder : public HttpDecoder::Visitor { // HttpDecoder::Visitor implementation. void OnError(HttpDecoder* /*decoder*/) override {} - bool OnCancelPushFrame(const CancelPushFrame& /*frame*/) override { - error_detail_ = "CANCEL_PUSH frame forbidden"; - return false; - } bool OnMaxPushIdFrame(const MaxPushIdFrame& /*frame*/) override { error_detail_ = "MAX_PUSH_ID frame forbidden"; return false; @@ -125,26 +118,6 @@ class AlpsFrameDecoder : public HttpDecoder::Visitor { QUICHE_NOTREACHED(); return false; } - bool OnPushPromiseFrameStart(QuicByteCount /*header_length*/) override { - error_detail_ = "PUSH_PROMISE frame forbidden"; - return false; - } - bool OnPushPromiseFramePushId( - PushId /*push_id*/, - QuicByteCount - /*push_id_length*/, - QuicByteCount /*header_block_length*/) override { - QUICHE_NOTREACHED(); - return false; - } - bool OnPushPromiseFramePayload(absl::string_view /*payload*/) override { - QUICHE_NOTREACHED(); - return false; - } - bool OnPushPromiseFrameEnd() override { - QUICHE_NOTREACHED(); - return false; - } bool OnPriorityUpdateFrameStart(QuicByteCount /*header_length*/) override { error_detail_ = "PRIORITY_UPDATE frame forbidden"; return false; @@ -220,8 +193,7 @@ class QuicSpdySession::SpdyFramerVisitor header_list_.Clear(); } - void OnStreamFrameData(SpdyStreamId /*stream_id*/, - const char* /*data*/, + void OnStreamFrameData(SpdyStreamId /*stream_id*/, const char* /*data*/, size_t /*len*/) override { QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); CloseConnection("SPDY DATA frame received.", @@ -309,8 +281,7 @@ class QuicSpdySession::SpdyFramerVisitor code); } - void OnDataFrameHeader(SpdyStreamId /*stream_id*/, - size_t /*length*/, + void OnDataFrameHeader(SpdyStreamId /*stream_id*/, size_t /*length*/, bool /*fin*/) override { QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); CloseConnection("SPDY DATA frame received.", @@ -343,13 +314,9 @@ class QuicSpdySession::SpdyFramerVisitor QUIC_INVALID_HEADERS_STREAM_DATA); } - void OnHeaders(SpdyStreamId stream_id, - bool has_priority, - int weight, - SpdyStreamId /* parent_stream_id */, - bool /* exclusive */, - bool fin, - bool /*end*/) override { + void OnHeaders(SpdyStreamId stream_id, bool has_priority, int weight, + SpdyStreamId /* parent_stream_id */, bool /* exclusive */, + bool fin, bool /*end*/) override { if (!session_->IsConnected()) { return; } @@ -377,8 +344,7 @@ class QuicSpdySession::SpdyFramerVisitor QUIC_INVALID_HEADERS_STREAM_DATA); } - void OnPushPromise(SpdyStreamId stream_id, - SpdyStreamId promised_stream_id, + void OnPushPromise(SpdyStreamId stream_id, SpdyStreamId promised_stream_id, bool /*end*/) override { QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); if (session_->perspective() != Perspective::IS_CLIENT) { @@ -394,10 +360,8 @@ class QuicSpdySession::SpdyFramerVisitor void OnContinuation(SpdyStreamId /*stream_id*/, bool /*end*/) override {} - void OnPriority(SpdyStreamId stream_id, - SpdyStreamId /* parent_id */, - int weight, - bool /* exclusive */) override { + void OnPriority(SpdyStreamId stream_id, SpdyStreamId /* parent_id */, + int weight, bool /* exclusive */) override { QUICHE_DCHECK(!VersionUsesHttp3(session_->transport_version())); if (!session_->IsConnected()) { return; @@ -420,10 +384,8 @@ class QuicSpdySession::SpdyFramerVisitor } // SpdyFramerDebugVisitorInterface implementation - void OnSendCompressedFrame(SpdyStreamId /*stream_id*/, - SpdyFrameType /*type*/, - size_t payload_len, - size_t frame_len) override { + void OnSendCompressedFrame(SpdyStreamId /*stream_id*/, SpdyFrameType /*type*/, + size_t payload_len, size_t frame_len) override { if (payload_len == 0) { QUIC_BUG(quic_bug_10360_1) << "Zero payload length."; return; @@ -462,14 +424,9 @@ Http3DebugVisitor::~Http3DebugVisitor() {} // Expected unidirectional static streams Requirement can be found at // https://tools.ietf.org/html/draft-ietf-quic-http-22#section-6.2. QuicSpdySession::QuicSpdySession( - QuicConnection* connection, - QuicSession::Visitor* visitor, - const QuicConfig& config, - const ParsedQuicVersionVector& supported_versions) - : QuicSession(connection, - visitor, - config, - supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions) + : QuicSession(connection, visitor, config, supported_versions, /*num_expected_unidirectional_static_streams = */ VersionUsesHttp3(connection->transport_version()) ? static_cast( @@ -497,10 +454,10 @@ QuicSpdySession::QuicSpdySession( spdy_framer_visitor_(new SpdyFramerVisitor(this)), debug_visitor_(nullptr), destruction_indicator_(123456789), - server_push_enabled_(true), - next_available_datagram_flow_id_(perspective() == Perspective::IS_SERVER - ? kFirstDatagramFlowIdServer - : kFirstDatagramFlowIdClient) { + allow_extended_connect_( + GetQuicReloadableFlag(quic_verify_request_headers_2) && + perspective() == Perspective::IS_SERVER && + VersionUsesHttp3(transport_version())) { h2_deframer_.set_visitor(spdy_framer_visitor_.get()); h2_deframer_.set_debug_visitor(spdy_framer_visitor_.get()); spdy_framer_.set_debug_visitor(spdy_framer_visitor_.get()); @@ -554,13 +511,25 @@ void QuicSpdySession::FillSettingsFrame() { qpack_maximum_blocked_streams_; settings_.values[SETTINGS_MAX_FIELD_SECTION_SIZE] = max_inbound_header_list_size_; - if (ShouldNegotiateHttp3Datagram() && version().UsesHttp3()) { - QUIC_RELOADABLE_FLAG_COUNT(quic_h3_datagram); - settings_.values[SETTINGS_H3_DATAGRAM] = 1; + if (version().UsesHttp3()) { + HttpDatagramSupport local_http_datagram_support = + LocalHttpDatagramSupport(); + if (local_http_datagram_support == HttpDatagramSupport::kDraft00 || + local_http_datagram_support == HttpDatagramSupport::kDraft00And04) { + settings_.values[SETTINGS_H3_DATAGRAM_DRAFT00] = 1; + } + if (local_http_datagram_support == HttpDatagramSupport::kDraft04 || + local_http_datagram_support == HttpDatagramSupport::kDraft00And04) { + settings_.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + } } if (WillNegotiateWebTransport()) { settings_.values[SETTINGS_WEBTRANS_DRAFT00] = 1; } + if (allow_extended_connect()) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_verify_request_headers_2, 1, 3); + settings_.values[SETTINGS_ENABLE_CONNECT_PROTOCOL] = 1; + } } void QuicSpdySession::OnDecoderStreamError(QuicErrorCode error_code, @@ -580,8 +549,7 @@ void QuicSpdySession::OnEncoderStreamError(QuicErrorCode error_code, } void QuicSpdySession::OnStreamHeadersPriority( - QuicStreamId stream_id, - const spdy::SpdyStreamPrecedence& precedence) { + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence) { QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); if (!stream) { // It's quite possible to receive headers after a stream has been reset. @@ -590,8 +558,7 @@ void QuicSpdySession::OnStreamHeadersPriority( stream->OnStreamHeadersPriority(precedence); } -void QuicSpdySession::OnStreamHeaderList(QuicStreamId stream_id, - bool fin, +void QuicSpdySession::OnStreamHeaderList(QuicStreamId stream_id, bool fin, size_t frame_len, const QuicHeaderList& header_list) { if (IsStaticStream(stream_id)) { @@ -630,8 +597,7 @@ void QuicSpdySession::OnStreamHeaderList(QuicStreamId stream_id, } void QuicSpdySession::OnPriorityFrame( - QuicStreamId stream_id, - const spdy::SpdyStreamPrecedence& precedence) { + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence) { QuicSpdyStream* stream = GetOrCreateSpdyDataStream(stream_id); if (!stream) { // It's quite possible to receive a PRIORITY frame after a stream has been @@ -708,9 +674,7 @@ size_t QuicSpdySession::ProcessHeaderData(const struct iovec& iov) { } size_t QuicSpdySession::WriteHeadersOnHeadersStream( - QuicStreamId id, - SpdyHeaderBlock headers, - bool fin, + QuicStreamId id, SpdyHeaderBlock headers, bool fin, const spdy::SpdyStreamPrecedence& precedence, QuicReferenceCountedPointer ack_listener) { QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); @@ -723,8 +687,7 @@ size_t QuicSpdySession::WriteHeadersOnHeadersStream( } size_t QuicSpdySession::WritePriority(QuicStreamId id, - QuicStreamId parent_stream_id, - int weight, + QuicStreamId parent_stream_id, int weight, bool exclusive) { QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); SpdyPriorityIR priority_frame(id, parent_stream_id, weight, exclusive); @@ -796,17 +759,13 @@ bool QuicSpdySession::OnStreamsBlockedFrame( void QuicSpdySession::SendHttp3GoAway(QuicErrorCode error_code, const std::string& reason) { - QUICHE_DCHECK_EQ(perspective(), Perspective::IS_SERVER); QUICHE_DCHECK(VersionUsesHttp3(transport_version())); - if (GetQuicReloadableFlag(quic_encrypted_goaway)) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_encrypted_goaway, 2, 2); - if (!IsEncryptionEstablished()) { - QUIC_CODE_COUNT(quic_h3_goaway_before_encryption_established); - connection()->CloseConnection( - error_code, reason, - ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); - return; - } + if (!IsEncryptionEstablished()) { + QUIC_CODE_COUNT(quic_h3_goaway_before_encryption_established); + connection()->CloseConnection( + error_code, reason, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; } QuicStreamId stream_id; @@ -858,10 +817,6 @@ void QuicSpdySession::WritePushPromise(QuicStreamId original_stream_id, absl::string_view(frame.data(), frame.size()), false, nullptr); } -bool QuicSpdySession::server_push_enabled() const { - return VersionUsesHttp3(transport_version()) ? false : server_push_enabled_; -} - void QuicSpdySession::SendInitialData() { if (!VersionUsesHttp3(transport_version())) { return; @@ -910,8 +865,7 @@ QuicSpdyStream* QuicSpdySession::GetOrCreateSpdyDataStream( } void QuicSpdySession::OnNewEncryptionKeyAvailable( - EncryptionLevel level, - std::unique_ptr encrypter) { + EncryptionLevel level, std::unique_ptr encrypter) { QuicSession::OnNewEncryptionKeyAvailable(level, std::move(encrypter)); if (IsEncryptionEstablished()) { // Send H3 SETTINGs once encryption is established. @@ -919,13 +873,15 @@ void QuicSpdySession::OnNewEncryptionKeyAvailable( } } -bool QuicSpdySession::ShouldNegotiateWebTransport() { - return false; -} +bool QuicSpdySession::ShouldNegotiateWebTransport() { return false; } + +bool QuicSpdySession::ShouldNegotiateDatagramContexts() { return false; } + +bool QuicSpdySession::ShouldValidateWebTransportVersion() const { return true; } bool QuicSpdySession::WillNegotiateWebTransport() { - return ShouldNegotiateHttp3Datagram() && version().UsesHttp3() && - ShouldNegotiateWebTransport(); + return LocalHttpDatagramSupport() != HttpDatagramSupport::kNone && + version().UsesHttp3() && ShouldNegotiateWebTransport(); } // True if there are open HTTP requests. @@ -935,20 +891,20 @@ bool QuicSpdySession::ShouldKeepConnectionAlive() const { return GetNumActiveStreams() + pending_streams_size() > 0; } -bool QuicSpdySession::UsesPendingStreams() const { - // QuicSpdySession supports PendingStreams, therefore this method should - // eventually just return true. However, pending streams can only be used if - // unidirectional stream type is supported. - return VersionUsesHttp3(transport_version()); +bool QuicSpdySession::UsesPendingStreamForFrame(QuicFrameType type, + QuicStreamId stream_id) const { + // Pending streams can only be used to handle unidirectional stream with + // STREAM & RESET_STREAM frames in IETF QUIC. + return VersionUsesHttp3(transport_version()) && + (type == STREAM_FRAME || type == RST_STREAM_FRAME) && + QuicUtils::GetStreamType(stream_id, perspective(), + IsIncomingStream(stream_id), + version()) == READ_UNIDIRECTIONAL; } size_t QuicSpdySession::WriteHeadersOnHeadersStreamImpl( - QuicStreamId id, - spdy::SpdyHeaderBlock headers, - bool fin, - QuicStreamId parent_stream_id, - int weight, - bool exclusive, + QuicStreamId id, spdy::SpdyHeaderBlock headers, bool fin, + QuicStreamId parent_stream_id, int weight, bool exclusive, QuicReferenceCountedPointer ack_listener) { QUICHE_DCHECK(!VersionUsesHttp3(transport_version())); @@ -982,10 +938,8 @@ size_t QuicSpdySession::WriteHeadersOnHeadersStreamImpl( } void QuicSpdySession::OnPromiseHeaderList( - QuicStreamId /*stream_id*/, - QuicStreamId /*promised_stream_id*/, - size_t /*frame_len*/, - const QuicHeaderList& /*header_list*/) { + QuicStreamId /*stream_id*/, QuicStreamId /*promised_stream_id*/, + size_t /*frame_len*/, const QuicHeaderList& /*header_list*/) { std::string error = "OnPromiseHeaderList should be overridden in client code."; QUIC_BUG(quic_bug_10360_6) << error; @@ -1015,8 +969,7 @@ bool QuicSpdySession::ResumeApplicationState(ApplicationState* cached_state) { } absl::optional QuicSpdySession::OnAlpsData( - const uint8_t* alps_data, - size_t alps_length) { + const uint8_t* alps_data, size_t alps_length) { AlpsFrameDecoder alps_frame_decoder(this); HttpDecoder decoder(&alps_frame_decoder); decoder.ProcessInput(reinterpret_cast(alps_data), alps_length); @@ -1084,6 +1037,19 @@ absl::optional QuicSpdySession::OnSettingsFrameViaAlps( return absl::nullopt; } +bool QuicSpdySession::VerifySettingIsZeroOrOne(uint64_t id, uint64_t value) { + if (value == 0 || value == 1) { + return true; + } + std::string error_details = absl::StrCat( + "Received ", + H3SettingsToString(static_cast(id)), + " with invalid value ", value); + QUIC_PEER_BUG(bad received setting) << ENDPOINT << error_details; + CloseConnectionWithDetails(QUIC_HTTP_INVALID_SETTING_VALUE, error_details); + return false; +} + bool QuicSpdySession::OnSetting(uint64_t id, uint64_t value) { any_settings_received_ = true; @@ -1156,6 +1122,18 @@ bool QuicSpdySession::OnSetting(uint64_t id, uint64_t value) { } break; } + case SETTINGS_ENABLE_CONNECT_PROTOCOL: { + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_ENABLE_CONNECT_PROTOCOL received with value " + << value; + if (!VerifySettingIsZeroOrOne(id, value)) { + return false; + } + if (perspective() == Perspective::IS_CLIENT) { + allow_extended_connect_ = value != 0; + } + break; + } case spdy::SETTINGS_ENABLE_PUSH: ABSL_FALLTHROUGH_INTENDED; case spdy::SETTINGS_MAX_CONCURRENT_STREAMS: @@ -1168,24 +1146,47 @@ bool QuicSpdySession::OnSetting(uint64_t id, uint64_t value) { absl::StrCat("received HTTP/2 specific setting in HTTP/3 session: ", id)); return false; - case SETTINGS_H3_DATAGRAM: { - if (!ShouldNegotiateHttp3Datagram()) { + case SETTINGS_H3_DATAGRAM_DRAFT00: { + HttpDatagramSupport local_http_datagram_support = + LocalHttpDatagramSupport(); + if (local_http_datagram_support != HttpDatagramSupport::kDraft00 && + local_http_datagram_support != HttpDatagramSupport::kDraft00And04) { + break; + } + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_H3_DATAGRAM_DRAFT00 received with value " + << value; + if (!version().UsesHttp3()) { + break; + } + if (!VerifySettingIsZeroOrOne(id, value)) { + return false; + } + if (value && http_datagram_support_ != HttpDatagramSupport::kDraft04) { + // If both draft-00 and draft-04 are supported, use draft-04. + http_datagram_support_ = HttpDatagramSupport::kDraft00; + } + break; + } + case SETTINGS_H3_DATAGRAM_DRAFT04: { + HttpDatagramSupport local_http_datagram_support = + LocalHttpDatagramSupport(); + if (local_http_datagram_support != HttpDatagramSupport::kDraft04 && + local_http_datagram_support != HttpDatagramSupport::kDraft00And04) { break; } - QUIC_DVLOG(1) << ENDPOINT << "SETTINGS_H3_DATAGRAM received with value " + QUIC_DVLOG(1) << ENDPOINT + << "SETTINGS_H3_DATAGRAM_DRAFT04 received with value " << value; if (!version().UsesHttp3()) { break; } - if (value != 0 && value != 1) { - std::string error_details = absl::StrCat( - "received SETTINGS_H3_DATAGRAM with invalid value ", value); - QUIC_PEER_BUG(quic_peer_bug_10360_7) << ENDPOINT << error_details; - CloseConnectionWithDetails(QUIC_HTTP_RECEIVE_SPDY_SETTING, - error_details); + if (!VerifySettingIsZeroOrOne(id, value)) { return false; } - h3_datagram_supported_ = !!value; + if (value) { + http_datagram_support_ = HttpDatagramSupport::kDraft04; + } break; } case SETTINGS_WEBTRANS_DRAFT00: @@ -1195,17 +1196,13 @@ bool QuicSpdySession::OnSetting(uint64_t id, uint64_t value) { QUIC_DVLOG(1) << ENDPOINT << "SETTINGS_ENABLE_WEBTRANSPORT received with value " << value; - if (value != 0 && value != 1) { - std::string error_details = absl::StrCat( - "received SETTINGS_ENABLE_WEBTRANSPORT with invalid value ", - value); - QUIC_PEER_BUG(invalid SETTINGS_ENABLE_WEBTRANSPORT value) - << ENDPOINT << error_details; - CloseConnectionWithDetails(QUIC_HTTP_RECEIVE_SPDY_SETTING, - error_details); + if (!VerifySettingIsZeroOrOne(id, value)) { return false; } peer_supports_webtransport_ = (value == 1); + if (perspective() == Perspective::IS_CLIENT && value == 1) { + allow_extended_connect_ = true; + } break; default: QUIC_DVLOG(1) << ENDPOINT << "Unknown setting identifier " << id @@ -1239,8 +1236,7 @@ bool QuicSpdySession::OnSetting(uint64_t id, uint64_t value) { return true; } QUIC_DVLOG(1) << ENDPOINT << "SETTINGS_ENABLE_PUSH received with value " - << value; - server_push_enabled_ = value; + << value << ", ignoring."; break; } else { QUIC_DLOG(ERROR) @@ -1276,8 +1272,7 @@ bool QuicSpdySession::ShouldReleaseHeadersStreamSequencerBuffer() { return false; } -void QuicSpdySession::OnHeaders(SpdyStreamId stream_id, - bool has_priority, +void QuicSpdySession::OnHeaders(SpdyStreamId stream_id, bool has_priority, const spdy::SpdyStreamPrecedence& precedence, bool fin) { if (has_priority) { @@ -1401,8 +1396,9 @@ QuicStream* QuicSpdySession::ProcessPendingStream(PendingStream* pending) { return receive_control_stream_; } case kServerPushStream: { // Push Stream. - QuicSpdyStream* stream = CreateIncomingStream(pending); - return stream; + CloseConnectionWithDetails(QUIC_HTTP_RECEIVE_SERVER_PUSH, + "Received server push stream"); + return nullptr; } case kQpackEncoderStream: { // QPACK encoder stream. if (qpack_encoder_receive_stream_) { @@ -1458,7 +1454,9 @@ QuicStream* QuicSpdySession::ProcessPendingStream(PendingStream* pending) { default: break; } - MaybeSendStopSendingFrame(pending->id(), QUIC_STREAM_STREAM_CREATION_ERROR); + MaybeSendStopSendingFrame( + pending->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_STREAM_CREATION_ERROR)); pending->StopReading(); return nullptr; } @@ -1611,9 +1609,7 @@ void QuicSpdySession::CloseConnectionOnDuplicateHttp3UnidirectionalStreams( // static void QuicSpdySession::LogHeaderCompressionRatioHistogram( - bool using_qpack, - bool is_sent, - QuicByteCount compressed, + bool using_qpack, bool is_sent, QuicByteCount compressed, QuicByteCount uncompressed) { if (compressed <= 0 || uncompressed <= 0) { return; @@ -1655,81 +1651,122 @@ void QuicSpdySession::LogHeaderCompressionRatioHistogram( } } -QuicDatagramFlowId QuicSpdySession::GetNextDatagramFlowId() { - QuicDatagramFlowId result = next_available_datagram_flow_id_; - next_available_datagram_flow_id_ += kDatagramFlowIdIncrement; - return result; -} - -MessageStatus QuicSpdySession::SendHttp3Datagram(QuicDatagramFlowId flow_id, - absl::string_view payload) { +MessageStatus QuicSpdySession::SendHttp3Datagram( + QuicDatagramStreamId stream_id, + absl::optional context_id, + absl::string_view payload) { + if (!SupportsH3Datagram()) { + QUIC_BUG(send http datagram too early) + << "Refusing to send HTTP Datagram before SETTINGS received"; + return MESSAGE_STATUS_INTERNAL_ERROR; + } + uint64_t stream_id_to_write = stream_id; + if (http_datagram_support_ != HttpDatagramSupport::kDraft00) { + // Stream ID is sent divided by four as per the specification. + stream_id_to_write /= kHttpDatagramStreamIdDivisor; + } size_t slice_length = - QuicDataWriter::GetVarInt62Len(flow_id) + payload.length(); - QuicUniqueBufferPtr buffer = MakeUniqueBuffer( - connection()->helper()->GetStreamSendBufferAllocator(), slice_length); - QuicDataWriter writer(slice_length, buffer.get()); - if (!writer.WriteVarInt62(flow_id)) { - QUIC_BUG(quic_bug_10360_10) << "Failed to write HTTP/3 datagram flow ID"; + QuicDataWriter::GetVarInt62Len(stream_id_to_write) + payload.length(); + if (context_id.has_value()) { + slice_length += QuicDataWriter::GetVarInt62Len(context_id.value()); + } + QuicBuffer buffer(connection()->helper()->GetStreamSendBufferAllocator(), + slice_length); + QuicDataWriter writer(slice_length, buffer.data()); + if (!writer.WriteVarInt62(stream_id_to_write)) { + QUIC_BUG(h3 datagram stream ID write fail) + << "Failed to write HTTP/3 datagram stream ID"; return MESSAGE_STATUS_INTERNAL_ERROR; } + if (context_id.has_value()) { + if (!writer.WriteVarInt62(context_id.value())) { + QUIC_BUG(h3 datagram context ID write fail) + << "Failed to write HTTP/3 datagram context ID"; + return MESSAGE_STATUS_INTERNAL_ERROR; + } + } if (!writer.WriteBytes(payload.data(), payload.length())) { - QUIC_BUG(quic_bug_10360_11) << "Failed to write HTTP/3 datagram payload"; + QUIC_BUG(h3 datagram payload write fail) + << "Failed to write HTTP/3 datagram payload"; return MESSAGE_STATUS_INTERNAL_ERROR; } - QuicMemSlice slice(std::move(buffer), slice_length); + QuicMemSlice slice(std::move(buffer)); return datagram_queue()->SendOrQueueDatagram(std::move(slice)); } -void QuicSpdySession::RegisterHttp3FlowId( - QuicDatagramFlowId flow_id, - QuicSpdySession::Http3DatagramVisitor* visitor) { - QUICHE_DCHECK_NE(visitor, nullptr); - auto insertion_result = h3_datagram_registrations_.insert({flow_id, visitor}); - QUIC_BUG_IF(quic_bug_12477_7, !insertion_result.second) - << "Attempted to doubly register HTTP/3 flow ID " << flow_id; +void QuicSpdySession::SetMaxDatagramTimeInQueueForStreamId( + QuicStreamId /*stream_id*/, QuicTime::Delta max_time_in_queue) { + // TODO(b/184598230): implement this in a way that works for multiple sessions + // on a same connection. + datagram_queue()->SetMaxTimeInQueue(max_time_in_queue); } -void QuicSpdySession::UnregisterHttp3FlowId(QuicDatagramFlowId flow_id) { - size_t num_erased = h3_datagram_registrations_.erase(flow_id); - QUIC_BUG_IF(quic_bug_12477_8, num_erased != 1) - << "Attempted to unregister unknown HTTP/3 flow ID " << flow_id; +void QuicSpdySession::RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id, + QuicStreamId stream_id) { + h3_datagram_flow_id_to_stream_id_map_[flow_id] = stream_id; } -void QuicSpdySession::SetMaxTimeInQueueForFlowId( - QuicDatagramFlowId /*flow_id*/, - QuicTime::Delta max_time_in_queue) { - // TODO(b/184598230): implement this in a way that works for multiple sessions - // on a same connection. - datagram_queue()->SetMaxTimeInQueue(max_time_in_queue); +void QuicSpdySession::UnregisterHttp3DatagramFlowId( + QuicDatagramStreamId flow_id) { + h3_datagram_flow_id_to_stream_id_map_.erase(flow_id); } void QuicSpdySession::OnMessageReceived(absl::string_view message) { QuicSession::OnMessageReceived(message); - if (!h3_datagram_supported_) { - QUIC_DLOG(ERROR) << "Ignoring unexpected received HTTP/3 datagram"; + if (!SupportsH3Datagram()) { + QUIC_DLOG(INFO) << "Ignoring unexpected received HTTP/3 datagram"; return; } QuicDataReader reader(message); - QuicDatagramFlowId flow_id; - if (!reader.ReadVarInt62(&flow_id)) { - QUIC_DLOG(ERROR) << "Failed to parse flow ID in received HTTP/3 datagram"; + uint64_t stream_id64; + if (!reader.ReadVarInt62(&stream_id64)) { + QUIC_DLOG(ERROR) << "Failed to parse stream ID in received HTTP/3 datagram"; + return; + } + if (http_datagram_support_ != HttpDatagramSupport::kDraft00) { + // Stream ID is sent divided by four as per the specification. + stream_id64 *= kHttpDatagramStreamIdDivisor; + } + if (perspective() == Perspective::IS_SERVER && + http_datagram_support_ == HttpDatagramSupport::kDraft00) { + auto it = h3_datagram_flow_id_to_stream_id_map_.find(stream_id64); + if (it == h3_datagram_flow_id_to_stream_id_map_.end()) { + QUIC_DLOG(INFO) << "Received unknown HTTP/3 datagram flow ID " + << stream_id64; + return; + } + stream_id64 = it->second; + } + if (stream_id64 > std::numeric_limits::max()) { + // TODO(b/181256914) make this a connection close once we deprecate + // draft-ietf-masque-h3-datagram-00 in favor of later drafts. + QUIC_DLOG(ERROR) << "Received unexpectedly high HTTP/3 datagram stream ID " + << stream_id64; return; } - auto it = h3_datagram_registrations_.find(flow_id); - if (it == h3_datagram_registrations_.end()) { - // TODO(dschinazi) buffer unknown HTTP/3 datagram flow IDs for a short + QuicStreamId stream_id = static_cast(stream_id64); + QuicSpdyStream* stream = + static_cast(GetActiveStream(stream_id)); + if (stream == nullptr) { + QUIC_DLOG(INFO) << "Received HTTP/3 datagram for unknown stream ID " + << stream_id; + // TODO(b/181256914) buffer unknown HTTP/3 datagram flow IDs for a short // period of time in case they were reordered. - QUIC_DLOG(ERROR) << "Received unknown HTTP/3 datagram flow ID " << flow_id; return; } - absl::string_view payload = reader.ReadRemainingPayload(); - it->second->OnHttp3Datagram(flow_id, payload); + stream->OnDatagramReceived(&reader); } bool QuicSpdySession::SupportsWebTransport() { - return WillNegotiateWebTransport() && h3_datagram_supported_ && - peer_supports_webtransport_; + return WillNegotiateWebTransport() && SupportsH3Datagram() && + peer_supports_webtransport_ && + (!GetQuicReloadableFlag(quic_verify_request_headers_2) || + allow_extended_connect_); +} + +bool QuicSpdySession::SupportsH3Datagram() const { + return http_datagram_support_ != HttpDatagramSupport::kNone; } WebTransportHttp3* QuicSpdySession::GetWebTransportSession( @@ -1762,8 +1799,7 @@ void QuicSpdySession::OnStreamWaitingForClientSettings(QuicStreamId id) { } void QuicSpdySession::AssociateIncomingWebTransportStreamWithSession( - WebTransportSessionId session_id, - QuicStreamId stream_id) { + WebTransportSessionId session_id, QuicStreamId stream_id) { if (QuicUtils::IsOutgoingStreamId(version(), stream_id, perspective())) { QUIC_BUG(AssociateIncomingWebTransportStreamWithSession got outgoing stream) << ENDPOINT @@ -1859,8 +1895,48 @@ void QuicSpdySession::DatagramObserver::OnDatagramProcessed( session_->OnDatagramProcessed(status); } -bool QuicSpdySession::ShouldNegotiateHttp3Datagram() { - return GetQuicReloadableFlag(quic_h3_datagram); +HttpDatagramSupport QuicSpdySession::LocalHttpDatagramSupport() { + return HttpDatagramSupport::kNone; +} + +std::string HttpDatagramSupportToString( + HttpDatagramSupport http_datagram_support) { + switch (http_datagram_support) { + case HttpDatagramSupport::kNone: + return "None"; + case HttpDatagramSupport::kDraft00: + return "Draft00"; + case HttpDatagramSupport::kDraft04: + return "Draft04"; + case HttpDatagramSupport::kDraft00And04: + return "Draft00And04"; + } + return absl::StrCat("Unknown(", static_cast(http_datagram_support), ")"); +} + +std::ostream& operator<<(std::ostream& os, + const HttpDatagramSupport& http_datagram_support) { + os << HttpDatagramSupportToString(http_datagram_support); + return os; +} + +// Must not be called after Initialize(). +void QuicSpdySession::set_allow_extended_connect(bool allow_extended_connect) { + QUIC_BUG_IF(extended connect wrong version, + !GetQuicReloadableFlag(quic_verify_request_headers_2) || + !VersionUsesHttp3(transport_version())) + << "Try to enable/disable extended CONNECT in Google QUIC"; + QUIC_BUG_IF(extended connect on client, + !GetQuicReloadableFlag(quic_verify_request_headers_2) || + perspective() == Perspective::IS_CLIENT) + << "Enabling/disabling extended CONNECT on the client side has no effect"; + if (ShouldNegotiateWebTransport()) { + QUIC_BUG_IF(disable extended connect, !allow_extended_connect) + << "Disabling extended CONNECT with web transport enabled has no " + "effect."; + return; + } + allow_extended_connect_ = allow_extended_connect; } #undef ENDPOINT // undef for jumbo builds diff --git a/gquiche/quic/core/http/quic_spdy_session.h b/gquiche/quic/core/http/quic_spdy_session.h index 0428cdab..1dfa9bca 100644 --- a/gquiche/quic/core/http/quic_spdy_session.h +++ b/gquiche/quic/core/http/quic_spdy_session.h @@ -6,6 +6,7 @@ #define QUICHE_QUIC_CORE_HTTP_QUIC_SPDY_SESSION_H_ #include +#include #include #include #include @@ -26,13 +27,13 @@ #include "gquiche/quic/core/qpack/qpack_encoder_stream_sender.h" #include "gquiche/quic/core/qpack/qpack_receive_stream.h" #include "gquiche/quic/core/qpack/qpack_send_stream.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/common/quiche_circular_deque.h" #include "gquiche/spdy/core/http2_frame_decoder_adapter.h" namespace quic { @@ -78,8 +79,6 @@ class QUIC_EXPORT_PRIVATE Http3DebugVisitor { virtual void OnAcceptChFrameReceivedViaAlps(const AcceptChFrame& /*frame*/) {} // Incoming HTTP/3 frames on the control stream. - // TODO(b/171463363): Remove. - virtual void OnCancelPushFrameReceived(const CancelPushFrame& /*frame*/) {} virtual void OnSettingsFrameReceived(const SettingsFrame& /*frame*/) = 0; virtual void OnGoAwayFrameReceived(const GoAwayFrame& /*frame*/) {} // TODO(b/171463363): Remove. @@ -92,19 +91,10 @@ class QUIC_EXPORT_PRIVATE Http3DebugVisitor { virtual void OnDataFrameReceived(QuicStreamId /*stream_id*/, QuicByteCount /*payload_length*/) {} virtual void OnHeadersFrameReceived( - QuicStreamId /*stream_id*/, - QuicByteCount /*compressed_headers_length*/) {} + QuicStreamId /*stream_id*/, QuicByteCount /*compressed_headers_length*/) { + } virtual void OnHeadersDecoded(QuicStreamId /*stream_id*/, QuicHeaderList /*headers*/) {} - // TODO(b/171463363): Remove. - virtual void OnPushPromiseFrameReceived(QuicStreamId /*stream_id*/, - QuicStreamId /*push_id*/, - QuicByteCount - /*compressed_headers_length*/) {} - // TODO(b/171463363): Remove. - virtual void OnPushPromiseDecoded(QuicStreamId /*stream_id*/, - QuicStreamId /*push_id*/, - QuicHeaderList /*headers*/) {} // Incoming HTTP/3 frames of unknown type on any stream. virtual void OnUnknownFrameReceived(QuicStreamId /*stream_id*/, @@ -125,17 +115,25 @@ class QUIC_EXPORT_PRIVATE Http3DebugVisitor { virtual void OnHeadersFrameSent( QuicStreamId /*stream_id*/, const spdy::SpdyHeaderBlock& /*header_block*/) {} - // TODO(b/171463363): Remove. - virtual void OnPushPromiseFrameSent( - QuicStreamId /*stream_id*/, - QuicStreamId - /*push_id*/, - const spdy::SpdyHeaderBlock& /*header_block*/) {} // 0-RTT related events. virtual void OnSettingsFrameResumed(const SettingsFrame& /*frame*/) {} }; +// Whether draft-ietf-masque-h3-datagram is supported on this session and if so +// which draft is currently in use. +enum class HttpDatagramSupport : uint8_t { + kNone = 0, // HTTP Datagrams are not supported for this session. + kDraft00 = 1, + kDraft04 = 2, + kDraft00And04 = 3, // only used locally, we only negotiate one draft. +}; + +QUIC_EXPORT_PRIVATE std::string HttpDatagramSupportToString( + HttpDatagramSupport http_datagram_support); +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const HttpDatagramSupport& http_datagram_support); + // A QUIC session for HTTP. class QUIC_EXPORT_PRIVATE QuicSpdySession : public QuicSession, @@ -143,8 +141,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession public QpackDecoder::EncoderStreamErrorDelegate { public: // Does not take ownership of |connection| or |visitor|. - QuicSpdySession(QuicConnection* connection, - QuicSession::Visitor* visitor, + QuicSpdySession(QuicConnection* connection, QuicSession::Visitor* visitor, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions); QuicSpdySession(const QuicSpdySession&) = delete; @@ -165,14 +162,12 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // Called by |headers_stream_| when headers with a priority have been // received for a stream. This method will only be called for server streams. virtual void OnStreamHeadersPriority( - QuicStreamId stream_id, - const spdy::SpdyStreamPrecedence& precedence); + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence); // Called by |headers_stream_| when headers have been completely received // for a stream. |fin| will be true if the fin flag was set in the headers // frame. - virtual void OnStreamHeaderList(QuicStreamId stream_id, - bool fin, + virtual void OnStreamHeaderList(QuicStreamId stream_id, bool fin, size_t frame_len, const QuicHeaderList& header_list); @@ -209,18 +204,14 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // If provided, |ack_notifier_delegate| will be registered to be notified when // we have seen ACKs for all packets resulting from this call. virtual size_t WriteHeadersOnHeadersStream( - QuicStreamId id, - spdy::SpdyHeaderBlock headers, - bool fin, + QuicStreamId id, spdy::SpdyHeaderBlock headers, bool fin, const spdy::SpdyStreamPrecedence& precedence, QuicReferenceCountedPointer ack_listener); // Writes an HTTP/2 PRIORITY frame the to peer. Returns the size in bytes of // the resulting PRIORITY frame. - size_t WritePriority(QuicStreamId id, - QuicStreamId parent_stream_id, - int weight, - bool exclusive); + size_t WritePriority(QuicStreamId id, QuicStreamId parent_stream_id, + int weight, bool exclusive); // Writes an HTTP/3 PRIORITY_UPDATE frame to the peer. void WriteHttp3PriorityUpdate(const PriorityUpdateFrame& priority_update); @@ -251,12 +242,6 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession const QuicHeadersStream* headers_stream() const { return headers_stream_; } - // Returns whether server push is enabled. - // For a Google QUIC client this always returns false. - // For a Google QUIC server this is set by incoming SETTINGS_ENABLE_PUSH. - // For an IETF QUIC client or server this returns false. - bool server_push_enabled() const; - // Called when the control stream receives HTTP/3 SETTINGS. // Returns false in case of 0-RTT if received settings are incompatible with // cached values, true otherwise. @@ -298,12 +283,16 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession qpack_maximum_blocked_streams_ = qpack_maximum_blocked_streams; } + // Should only be used by IETF QUIC server side. // Must not be called after Initialize(). // TODO(bnc): Move to constructor argument. void set_max_inbound_header_list_size(size_t max_inbound_header_list_size) { max_inbound_header_list_size_ = max_inbound_header_list_size; } + // Must not be called after Initialize(). + void set_allow_extended_connect(bool allow_extended_connect); + size_t max_outbound_header_list_size() const { return max_outbound_header_list_size_; } @@ -312,6 +301,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession return max_inbound_header_list_size_; } + bool allow_extended_connect() const { return allow_extended_connect_; } + // Returns true if the session has active request streams. bool HasActiveRequestStreams() const; @@ -372,8 +363,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // In order for measurements for different protocol to be comparable, the // caller must ensure that uncompressed size is the total length of header // names and values without any overhead. - static void LogHeaderCompressionRatioHistogram(bool using_qpack, - bool is_sent, + static void LogHeaderCompressionRatioHistogram(bool using_qpack, bool is_sent, QuicByteCount compressed, QuicByteCount uncompressed); @@ -397,41 +387,26 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // extension. virtual void OnAcceptChFrameReceivedViaAlps(const AcceptChFrame& /*frame*/); - // Generates a new HTTP/3 datagram flow ID. - QuicDatagramFlowId GetNextDatagramFlowId(); - - // Whether HTTP/3 datagrams are supported on this session, based on received - // SETTINGS. - bool h3_datagram_supported() const { return h3_datagram_supported_; } - - // Sends an HTTP/3 datagram. The flow ID is not part of |payload|. - MessageStatus SendHttp3Datagram(QuicDatagramFlowId flow_id, - absl::string_view payload); - - class QUIC_EXPORT_PRIVATE Http3DatagramVisitor { - public: - virtual ~Http3DatagramVisitor() {} - - // Called when an HTTP/3 datagram is received. |payload| does not contain - // the flow ID. - virtual void OnHttp3Datagram(QuicDatagramFlowId flow_id, - absl::string_view payload) = 0; - }; - - // Registers |visitor| to receive HTTP/3 datagrams for flow ID |flow_id|. This - // must not be called on a previously register flow ID without first calling - // UnregisterHttp3FlowId. |visitor| must be valid until a corresponding call - // to UnregisterHttp3FlowId. The flow ID must be unregistered before the - // QuicSpdySession is destroyed. - void RegisterHttp3FlowId(QuicDatagramFlowId flow_id, - Http3DatagramVisitor* visitor); - - // Unregister a given HTTP/3 datagram flow ID. - void UnregisterHttp3FlowId(QuicDatagramFlowId flow_id); + // Whether HTTP datagrams are supported on this session and which draft is in + // use, based on received SETTINGS. + HttpDatagramSupport http_datagram_support() const { + return http_datagram_support_; + } - // Sets max time in queue for a specified datagram flow ID. - void SetMaxTimeInQueueForFlowId(QuicDatagramFlowId flow_id, - QuicTime::Delta max_time_in_queue); + // This must not be used except by QuicSpdyStream::SendHttp3Datagram. + MessageStatus SendHttp3Datagram( + QuicDatagramStreamId stream_id, + absl::optional context_id, + absl::string_view payload); + // This must not be used except by QuicSpdyStream::SetMaxDatagramTimeInQueue. + void SetMaxDatagramTimeInQueueForStreamId(QuicStreamId stream_id, + QuicTime::Delta max_time_in_queue); + // This must not be used except by + // QuicSpdyStream::MaybeProcessReceivedWebTransportHeaders. + void RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id, + QuicStreamId stream_id); + // This must not be used except by QuicSpdyStream::OnClose. + void UnregisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id); // Override from QuicSession to support HTTP/3 datagrams. void OnMessageReceived(absl::string_view message) override; @@ -439,6 +414,9 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // Indicates whether the HTTP/3 session supports WebTransport. bool SupportsWebTransport(); + // Indicates whether both the peer and us support HTTP/3 Datagrams. + bool SupportsH3Datagram() const; + // Indicates whether the HTTP/3 session will indicate WebTransport support to // the peer. bool WillNegotiateWebTransport(); @@ -451,7 +429,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // until the SETTINGS are received. Only works for HTTP/3. bool ShouldBufferRequestsUntilSettings() { return version().UsesHttp3() && perspective() == Perspective::IS_SERVER && - WillNegotiateWebTransport(); + LocalHttpDatagramSupport() != HttpDatagramSupport::kNone; } // Returns if the incoming bidirectional streams should process data. This is @@ -464,8 +442,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // Links the specified stream with a WebTransport session. If the session is // not present, it is buffered until a corresponding stream is found. void AssociateIncomingWebTransportStreamWithSession( - WebTransportSessionId session_id, - QuicStreamId stream_id); + WebTransportSessionId session_id, QuicStreamId stream_id); void ProcessBufferedWebTransportStreamsForSession(WebTransportHttp3* session); @@ -490,6 +467,15 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession QuicSpdyStream* GetOrCreateSpdyDataStream(const QuicStreamId stream_id); + // Indicates whether we will try to negotiate datagram contexts on newly + // created WebTransport sessions over HTTP/3. + virtual bool ShouldNegotiateDatagramContexts(); + + // Indicates whether the client should check that the + // `Sec-Webtransport-Http3-Draft` header is valid. + // TODO(vasilvv): remove this once this is enabled in Chromium. + virtual bool ShouldValidateWebTransportVersion() const; + protected: // Override CreateIncomingStream(), CreateOutgoingBidirectionalStream() and // CreateOutgoingUnidirectionalStream() with QuicSpdyStream return type to @@ -515,7 +501,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession bool ShouldKeepConnectionAlive() const override; // Overridden to buffer incoming unidirectional streams for version 99. - bool UsesPendingStreams() const override; + bool UsesPendingStreamForFrame(QuicFrameType type, + QuicStreamId stream_id) const override; // Processes incoming unidirectional streams; parses the stream type, and // creates a new stream of the corresponding type. Returns the pointer to the @@ -523,17 +510,12 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession QuicStream* ProcessPendingStream(PendingStream* pending) override; size_t WriteHeadersOnHeadersStreamImpl( - QuicStreamId id, - spdy::SpdyHeaderBlock headers, - bool fin, - QuicStreamId parent_stream_id, - int weight, - bool exclusive, + QuicStreamId id, spdy::SpdyHeaderBlock headers, bool fin, + QuicStreamId parent_stream_id, int weight, bool exclusive, QuicReferenceCountedPointer ack_listener); void OnNewEncryptionKeyAvailable( - EncryptionLevel level, - std::unique_ptr encrypter) override; + EncryptionLevel level, std::unique_ptr encrypter) override; // Sets the maximum size of the header compression table spdy_framer_ is // willing to use to encode header blocks. @@ -556,8 +538,9 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // Called whenever a datagram is dequeued or dropped from datagram_queue(). virtual void OnDatagramProcessed(absl::optional status); - // Returns true if HTTP/3 datagram extension should be supported. - virtual bool ShouldNegotiateHttp3Datagram(); + // Returns which version of the HTTP/3 datagram extension we should advertise + // in settings and accept remote settings for. + virtual HttpDatagramSupport LocalHttpDatagramSupport(); private: friend class test::QuicSpdySessionPeer; @@ -583,10 +566,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // The following methods are called by the SimpleVisitor. // Called when a HEADERS frame has been received. - void OnHeaders(spdy::SpdyStreamId stream_id, - bool has_priority, - const spdy::SpdyStreamPrecedence& precedence, - bool fin); + void OnHeaders(spdy::SpdyStreamId stream_id, bool has_priority, + const spdy::SpdyStreamPrecedence& precedence, bool fin); // Called when a PRIORITY frame has been received. void OnPriority(spdy::SpdyStreamId stream_id, @@ -604,6 +585,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession void FillSettingsFrame(); + bool VerifySettingIsZeroOrOne(uint64_t id, uint64_t value); + std::unique_ptr qpack_encoder_; std::unique_ptr qpack_decoder_; @@ -684,10 +667,6 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // an use-after-free. int32_t destruction_indicator_; - // Used in Google QUIC only. Set every time SETTINGS_ENABLE_PUSH is received. - // Defaults to true. - bool server_push_enabled_; - // The identifier in the most recently received GOAWAY frame. Unset if no // GOAWAY frame has been received yet. absl::optional last_received_http3_goaway_id_; @@ -695,18 +674,18 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // frame has been sent yet. absl::optional last_sent_http3_goaway_id_; - // Value of the smallest unused HTTP/3 datagram flow ID that this endpoint's - // datagram flow ID allocation service will use next. - QuicDatagramFlowId next_available_datagram_flow_id_; - - // Whether both this endpoint and our peer support HTTP/3 datagrams. - bool h3_datagram_supported_ = false; + // Whether both this endpoint and our peer support HTTP datagrams and which + // draft is in use for this session. + HttpDatagramSupport http_datagram_support_ = HttpDatagramSupport::kNone; // Whether the peer has indicated WebTransport support. bool peer_supports_webtransport_ = false; - absl::flat_hash_map - h3_datagram_registrations_; + // This maps from draft-ietf-masque-h3-datagram-00 flow IDs to stream IDs. + // TODO(b/181256914) remove this when we deprecate support for that draft in + // favor of more recent ones. + absl::flat_hash_map + h3_datagram_flow_id_to_stream_id_map_; // Whether any settings have been received, either from the peer or from a // session ticket. @@ -720,6 +699,10 @@ class QUIC_EXPORT_PRIVATE QuicSpdySession // Limited to kMaxUnassociatedWebTransportStreams; when the list is full, // oldest streams are evicated first. std::list buffered_streams_; + + // On the server side, if true, advertise and accept extended CONNECT method. + // On the client side, true if the peer advertised extended CONNECT. + bool allow_extended_connect_; }; } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_session_test.cc b/gquiche/quic/core/http/quic_spdy_session_test.cc index 2db11c62..a4a92eb9 100644 --- a/gquiche/quic/core/http/quic_spdy_session_test.cc +++ b/gquiche/quic/core/http/quic_spdy_session_test.cc @@ -35,7 +35,6 @@ #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_encoder_peer.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" @@ -47,7 +46,6 @@ #include "gquiche/quic/test_tools/quic_stream_peer.h" #include "gquiche/quic/test_tools/quic_stream_send_buffer_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/quiche_endian.h" #include "gquiche/common/test_tools/quiche_test_utils.h" #include "gquiche/spdy/core/spdy_framer.h" @@ -175,10 +173,18 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { void OnHandshakePacketSent() override {} void OnHandshakeDoneReceived() override {} void OnNewTokenReceived(absl::string_view /*token*/) override {} - std::string GetAddressToken() const override { return ""; } + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_params*/) const override { + return ""; + } bool ValidateAddressToken(absl::string_view /*token*/) const override { return true; } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} MOCK_METHOD(void, OnCanWrite, (), (override)); @@ -188,6 +194,14 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { void OnConnectionClosed(QuicErrorCode /*error*/, ConnectionCloseSource /*source*/) override {} + SSL* GetSsl() const override { return nullptr; } + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } private: using QuicCryptoStream::session; @@ -210,32 +224,33 @@ class TestStream : public QuicSpdyStream { TestStream(QuicStreamId id, QuicSpdySession* session, StreamType type) : QuicSpdyStream(id, session, type) {} - TestStream(PendingStream* pending, QuicSpdySession* session, StreamType type) - : QuicSpdyStream(pending, session, type) {} + TestStream(PendingStream* pending, QuicSpdySession* session) + : QuicSpdyStream(pending, session) {} using QuicStream::CloseWriteSide; void OnBodyAvailable() override {} MOCK_METHOD(void, OnCanWrite, (), (override)); - MOCK_METHOD(bool, - RetransmitStreamData, + MOCK_METHOD(bool, RetransmitStreamData, (QuicStreamOffset, QuicByteCount, bool, TransmissionType), (override)); MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); + + protected: + bool AreHeadersValid(const QuicHeaderList& /*header_list*/) const override { + return true; + } }; class TestSession : public QuicSpdySession { public: explicit TestSession(QuicConnection* connection) - : QuicSpdySession(connection, - nullptr, - DefaultQuicConfig(), + : QuicSpdySession(connection, nullptr, DefaultQuicConfig(), CurrentSupportedVersions()), crypto_stream_(this), writev_consumes_all_data_(false) { - Initialize(); this->connection()->SetEncrypter( ENCRYPTION_FORWARD_SECURE, std::make_unique(connection->perspective())); @@ -288,11 +303,7 @@ class TestSession : public QuicSpdySession { } TestStream* CreateIncomingStream(PendingStream* pending) override { - QuicStreamId id = pending->id(); - TestStream* stream = new TestStream( - pending, this, - DetermineStreamType(id, connection()->version(), perspective(), - /*is_incoming=*/true, BIDIRECTIONAL)); + TestStream* stream = new TestStream(pending, this); ActivateStream(absl::WrapUnique(stream)); return stream; } @@ -310,12 +321,10 @@ class TestSession : public QuicSpdySession { return QuicSpdySession::GetOrCreateStream(stream_id); } - QuicConsumedData WritevData(QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, TransmissionType type, - absl::optional level) override { + EncryptionLevel level) override { bool fin = state != NO_FIN; QuicConsumedData consumed(write_length, fin); if (!writev_consumes_all_data_) { @@ -357,18 +366,26 @@ class TestSession : public QuicSpdySession { bool ShouldNegotiateWebTransport() override { return supports_webtransport_; } void set_supports_webtransport(bool value) { supports_webtransport_ = value; } + HttpDatagramSupport LocalHttpDatagramSupport() override { + return local_http_datagram_support_; + } + void set_local_http_datagram_support(HttpDatagramSupport value) { + local_http_datagram_support_ = value; + } + MOCK_METHOD(void, OnAcceptChFrame, (const AcceptChFrame&), (override)); using QuicSession::closed_streams; using QuicSession::ShouldKeepConnectionAlive; using QuicSpdySession::ProcessPendingStream; - using QuicSpdySession::UsesPendingStreams; + using QuicSpdySession::UsesPendingStreamForFrame; private: StrictMock crypto_stream_; bool writev_consumes_all_data_; bool supports_webtransport_ = false; + HttpDatagramSupport local_http_datagram_support_ = HttpDatagramSupport::kNone; }; class QuicSpdySessionTestBase : public QuicTestWithParam { @@ -382,13 +399,18 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { } protected: - explicit QuicSpdySessionTestBase(Perspective perspective) - : connection_( - new StrictMock(&helper_, - &alarm_factory_, - perspective, - SupportedVersions(GetParam()))), + explicit QuicSpdySessionTestBase(Perspective perspective, + bool allow_extended_connect) + : connection_(new StrictMock( + &helper_, &alarm_factory_, perspective, + SupportedVersions(GetParam()))), session_(connection_) { + if (perspective == Perspective::IS_SERVER && + VersionUsesHttp3(transport_version()) && + GetQuicReloadableFlag(quic_verify_request_headers_2)) { + session_.set_allow_extended_connect(allow_extended_connect); + } + session_.Initialize(); session_.config()->SetInitialStreamFlowControlWindowToSend( kInitialStreamFlowControlWindowForTest); session_.config()->SetInitialSessionFlowControlWindowToSend( @@ -421,7 +443,7 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { first_stream_id = QuicUtils::GetCryptoStreamId(transport_version()); } for (QuicStreamId i = first_stream_id; i < 100; i++) { - if (!QuicContainsKey(closed_streams_, i)) { + if (closed_streams_.find(i) == closed_streams_.end()) { EXPECT_FALSE(session_.IsClosedStream(i)) << " stream id: " << i; } else { EXPECT_TRUE(session_.IsClosedStream(i)) << " stream id: " << i; @@ -506,8 +528,7 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { } QuicStreamId StreamCountToId(QuicStreamCount stream_count, - Perspective perspective, - bool bidirectional) { + Perspective perspective, bool bidirectional) { // Calculate and build up stream ID rather than use // GetFirst... because the test that relies on this method // needs to do the stream count where #1 is 0/1/2/3, and not @@ -543,8 +564,9 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { void ReceiveWebTransportSettings() { SettingsFrame settings; - settings.values[SETTINGS_H3_DATAGRAM] = 1; + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; settings.values[SETTINGS_WEBTRANS_DRAFT00] = 1; + settings.values[SETTINGS_ENABLE_CONNECT_PROTOCOL] = 1; std::string data = std::string(1, kControlStream) + EncodeSettings(settings); QuicStreamId control_stream_id = @@ -558,7 +580,6 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { } void ReceiveWebTransportSession(WebTransportSessionId session_id) { - SetQuicReloadableFlag(quic_accept_empty_stream_frame_with_no_fin, true); QuicStreamFrame frame(session_id, /*fin=*/false, /*offset=*/0, absl::string_view()); session_.OnStreamFrame(frame); @@ -568,9 +589,16 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { headers.OnHeaderBlockStart(); headers.OnHeader(":method", "CONNECT"); headers.OnHeader(":protocol", "webtransport"); - headers.OnHeader("datagram-flow-id", - absl::StrCat(session_.GetNextDatagramFlowId())); + if (session_.http_datagram_support() == HttpDatagramSupport::kDraft00) { + headers.OnHeader("datagram-flow-id", absl::StrCat(session_id)); + } else { + headers.OnHeader("sec-webtransport-http3-draft02", "1"); + } stream->OnStreamHeaderList(/*fin=*/true, 0, headers); + if (session_.http_datagram_support() != HttpDatagramSupport::kDraft00) { + stream->OnCapsule( + Capsule::RegisterDatagramNoContext(DatagramFormatType::WEBTRANSPORT)); + } WebTransportHttp3* web_transport = session_.GetWebTransportSession(session_id); ASSERT_TRUE(web_transport != nullptr); @@ -590,6 +618,11 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { session_.OnStreamFrame(frame); } + void TestHttpDatagramSetting(HttpDatagramSupport local_support, + HttpDatagramSupport remote_support, + HttpDatagramSupport expected_support, + bool expected_datagram_supported); + MockQuicConnectionHelper helper_; MockAlarmFactory alarm_factory_; StrictMock* connection_; @@ -601,19 +634,32 @@ class QuicSpdySessionTestBase : public QuicTestWithParam { class QuicSpdySessionTestServer : public QuicSpdySessionTestBase { protected: QuicSpdySessionTestServer() - : QuicSpdySessionTestBase(Perspective::IS_SERVER) {} + : QuicSpdySessionTestBase(Perspective::IS_SERVER, true) {} }; -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSpdySessionTestServer, +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdySessionTestServer, ::testing::ValuesIn(AllSupportedVersions()), ::testing::PrintToStringParamName()); -TEST_P(QuicSpdySessionTestServer, UsesPendingStreams) { +TEST_P(QuicSpdySessionTestServer, UsesPendingStreamsForFrame) { if (!VersionUsesHttp3(transport_version())) { return; } - EXPECT_TRUE(session_.UsesPendingStreams()); + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + STOP_SENDING_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); } TEST_P(QuicSpdySessionTestServer, PeerAddress) { @@ -1177,7 +1223,6 @@ TEST_P(QuicSpdySessionTestServer, SendGoAway) { } TEST_P(QuicSpdySessionTestServer, SendGoAwayWithoutEncryption) { - SetQuicReloadableFlag(quic_encrypted_goaway, true); if (VersionHasIetfQuicFrames(transport_version())) { // HTTP/3 GOAWAY has different semantic and thus has its own test. return; @@ -1219,7 +1264,6 @@ TEST_P(QuicSpdySessionTestServer, SendHttp3GoAway) { } TEST_P(QuicSpdySessionTestServer, SendHttp3GoAwayWithoutEncryption) { - SetQuicReloadableFlag(quic_encrypted_goaway, true); if (!VersionUsesHttp3(transport_version())) { return; } @@ -1499,61 +1543,6 @@ TEST_P(QuicSpdySessionTestServer, HandshakeUnblocksFlowControlBlockedStream) { EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); } -TEST_P(QuicSpdySessionTestServer, - HandshakeUnblocksFlowControlBlockedCryptoStream) { - if (QuicVersionUsesCryptoFrames(transport_version()) || - connection_->encrypted_control_frames()) { - // QUIC version 47 onwards uses CRYPTO frames for the handshake, so this - // test doesn't make sense for those versions. With - // use_encryption_level_context, control frames can only be sent when - // encryption gets established, do not send BLOCKED for crypto streams. - return; - } - // Test that if the crypto stream is flow control blocked, then if the SHLO - // contains a larger send window offset, the stream becomes unblocked. - session_.set_writev_consumes_all_data(true); - TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); - EXPECT_FALSE(crypto_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); - QuicHeadersStream* headers_stream = - QuicSpdySessionPeer::GetHeadersStream(&session_); - EXPECT_FALSE(headers_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); - EXPECT_CALL(*connection_, SendControlFrame(_)) - .WillOnce(Invoke(&ClearControlFrame)); - for (QuicStreamId i = 0; !crypto_stream->IsFlowControlBlocked() && i < 1000u; - i++) { - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); - QuicStreamOffset offset = crypto_stream->stream_bytes_written(); - QuicConfig config; - CryptoHandshakeMessage crypto_message; - config.ToHandshakeMessage(&crypto_message, transport_version()); - crypto_stream->SendHandshakeMessage(crypto_message, ENCRYPTION_INITIAL); - char buf[1000]; - QuicDataWriter writer(1000, buf, quiche::NETWORK_BYTE_ORDER); - crypto_stream->WriteStreamData(offset, crypto_message.size(), &writer); - } - EXPECT_TRUE(crypto_stream->IsFlowControlBlocked()); - EXPECT_FALSE(headers_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); - EXPECT_FALSE(session_.HasDataToWrite()); - EXPECT_TRUE(crypto_stream->HasBufferedData()); - - // Now complete the crypto handshake, resulting in an increased flow control - // send window. - CompleteHandshake(); - EXPECT_TRUE(QuicSessionPeer::IsStreamWriteBlocked( - &session_, QuicUtils::GetCryptoStreamId(transport_version()))); - // Stream is now unblocked and will no longer have buffered data. - EXPECT_FALSE(crypto_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); -} - #if !defined(OS_IOS) // This test is failing flakily for iOS bots. // http://crbug.com/425050 @@ -1891,19 +1880,32 @@ TEST_P(QuicSpdySessionTestServer, ReduceMaxPushId) { class QuicSpdySessionTestClient : public QuicSpdySessionTestBase { protected: QuicSpdySessionTestClient() - : QuicSpdySessionTestBase(Perspective::IS_CLIENT) {} + : QuicSpdySessionTestBase(Perspective::IS_CLIENT, false) {} }; -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSpdySessionTestClient, +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdySessionTestClient, ::testing::ValuesIn(AllSupportedVersions()), ::testing::PrintToStringParamName()); -TEST_P(QuicSpdySessionTestClient, UsesPendingStreams) { +TEST_P(QuicSpdySessionTestClient, UsesPendingStreamsForFrame) { if (!VersionUsesHttp3(transport_version())) { return; } - EXPECT_TRUE(session_.UsesPendingStreams()); + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_TRUE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + STOP_SENDING_FRAME, QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); + EXPECT_FALSE(session_.UsesPendingStreamForFrame( + RST_STREAM_FRAME, QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_SERVER))); } // Regression test for crbug.com/977581. @@ -1917,12 +1919,6 @@ TEST_P(QuicSpdySessionTestClient, BadStreamFramePendingStream) { GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); // A bad stream frame with no data and no fin. QuicStreamFrame data1(stream_id1, false, 0, 0); - if (!GetQuicReloadableFlag(quic_accept_empty_stream_frame_with_no_fin)) { - EXPECT_CALL(*connection_, CloseConnection(_, _, _)) - .WillOnce( - Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); - EXPECT_CALL(*connection_, SendConnectionClosePacket(_, _, _)); - } session_.OnStreamFrame(data1); } @@ -2068,25 +2064,11 @@ TEST_P(QuicSpdySessionTestClient, Http3ServerPush) { std::string frame_type1 = absl::HexStringToBytes("01"); QuicStreamId stream_id1 = GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_RECEIVE_SERVER_PUSH, _, _)) + .Times(1); session_.OnStreamFrame(QuicStreamFrame(stream_id1, /* fin = */ false, /* offset = */ 0, frame_type1)); - - EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); - QuicStream* stream = session_.GetOrCreateStream(stream_id1); - EXPECT_EQ(1u, QuicStreamPeer::bytes_consumed(stream)); - EXPECT_EQ(1u, session_.flow_controller()->bytes_consumed()); - - // The same stream type can be encoded differently. - std::string frame_type2 = absl::HexStringToBytes("80000001"); - QuicStreamId stream_id2 = - GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 1); - session_.OnStreamFrame(QuicStreamFrame(stream_id2, /* fin = */ false, - /* offset = */ 0, frame_type2)); - - EXPECT_EQ(2u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); - stream = session_.GetOrCreateStream(stream_id2); - EXPECT_EQ(4u, QuicStreamPeer::bytes_consumed(stream)); - EXPECT_EQ(5u, session_.flow_controller()->bytes_consumed()); } TEST_P(QuicSpdySessionTestClient, Http3ServerPushOutofOrderFrame) { @@ -2113,11 +2095,10 @@ TEST_P(QuicSpdySessionTestClient, Http3ServerPushOutofOrderFrame) { // Receiving some stream data without stream type does not open the stream. session_.OnStreamFrame(data2); EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); - + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_RECEIVE_SERVER_PUSH, _, _)) + .Times(1); session_.OnStreamFrame(data1); - EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); - QuicStream* stream = session_.GetOrCreateStream(stream_id); - EXPECT_EQ(3u, stream->highest_received_byte_offset()); } TEST_P(QuicSpdySessionTestServer, OnStreamFrameLost) { @@ -2626,10 +2607,10 @@ TEST_P(QuicSpdySessionTestClient, ResetAfterInvalidIncomingStreamType) { return; } CompleteHandshake(); - ASSERT_TRUE(session_.UsesPendingStreams()); const QuicStreamId stream_id = GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); // Payload consists of two bytes. The first byte is an unknown unidirectional // stream type. The second one would be the type of a push stream, but it @@ -2673,10 +2654,10 @@ TEST_P(QuicSpdySessionTestClient, FinAfterInvalidIncomingStreamType) { return; } CompleteHandshake(); - ASSERT_TRUE(session_.UsesPendingStreams()); const QuicStreamId stream_id = GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); // Payload consists of two bytes. The first byte is an unknown unidirectional // stream type. The second one would be the type of a push stream, but it @@ -2712,10 +2693,10 @@ TEST_P(QuicSpdySessionTestClient, ResetInMiddleOfStreamType) { if (!VersionUsesHttp3(transport_version())) { return; } - ASSERT_TRUE(session_.UsesPendingStreams()); const QuicStreamId stream_id = GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); // Payload is the first byte of a two byte varint encoding. std::string payload = absl::HexStringToBytes("40"); @@ -2740,10 +2721,10 @@ TEST_P(QuicSpdySessionTestClient, FinInMiddleOfStreamType) { if (!VersionUsesHttp3(transport_version())) { return; } - ASSERT_TRUE(session_.UsesPendingStreams()); const QuicStreamId stream_id = GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 0); + ASSERT_TRUE(session_.UsesPendingStreamForFrame(STREAM_FRAME, stream_id)); // Payload is the first byte of a two byte varint encoding with a FIN. std::string payload = absl::HexStringToBytes("40"); @@ -2894,10 +2875,7 @@ TEST_P(QuicSpdySessionTestClient, Http3GoAwayLargerIdThanBefore) { session_.OnHttp3GoAway(stream_id2); } -// Test that receipt of CANCEL_PUSH frame does not result in closing the -// connection. -// TODO(b/151841240): Handle CANCEL_PUSH frames instead of ignoring them. -TEST_P(QuicSpdySessionTestClient, IgnoreCancelPush) { +TEST_P(QuicSpdySessionTestClient, CloseConnectionOnCancelPush) { if (!VersionUsesHttp3(transport_version())) { return; } @@ -2934,18 +2912,16 @@ TEST_P(QuicSpdySessionTestClient, IgnoreCancelPush) { "00"); // push ID QuicStreamFrame data3(receive_control_stream_id, /* fin = */ false, offset, cancel_push_frame); - EXPECT_CALL(debug_visitor, OnCancelPushFrameReceived(_)); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, + "CANCEL_PUSH frame received.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, + SendConnectionClosePacket(QUIC_HTTP_FRAME_ERROR, _, + "CANCEL_PUSH frame received.")); session_.OnStreamFrame(data3); } -TEST_P(QuicSpdySessionTestServer, ServerPushEnabledDefaultValue) { - if (VersionUsesHttp3(transport_version())) { - EXPECT_FALSE(session_.server_push_enabled()); - } else { - EXPECT_TRUE(session_.server_push_enabled()); - } -} - TEST_P(QuicSpdySessionTestServer, OnSetting) { CompleteHandshake(); if (VersionUsesHttp3(transport_version())) { @@ -2975,10 +2951,6 @@ TEST_P(QuicSpdySessionTestServer, OnSetting) { session_.OnSetting(SETTINGS_MAX_FIELD_SECTION_SIZE, 5); EXPECT_EQ(5u, session_.max_outbound_header_list_size()); - EXPECT_TRUE(session_.server_push_enabled()); - session_.OnSetting(spdy::SETTINGS_ENABLE_PUSH, 0); - EXPECT_FALSE(session_.server_push_enabled()); - spdy::HpackEncoder* hpack_encoder = QuicSpdySessionPeer::GetSpdyFramer(&session_)->GetHpackEncoder(); EXPECT_EQ(4096u, hpack_encoder->CurrentHeaderTableSizeSetting()); @@ -3105,10 +3077,7 @@ TEST_P(QuicSpdySessionTestServer, PeerClosesCriticalSendStream) { session_.OnStopSendingFrame(stop_sending_encoder_stream); } -// Test that receipt of CANCEL_PUSH frame does not result in closing the -// connection. -// TODO(b/151841240): Handle CANCEL_PUSH frames instead of ignoring them. -TEST_P(QuicSpdySessionTestServer, IgnoreCancelPush) { +TEST_P(QuicSpdySessionTestServer, CloseConnectionOnCancelPush) { if (!VersionUsesHttp3(transport_version())) { return; } @@ -3145,7 +3114,13 @@ TEST_P(QuicSpdySessionTestServer, IgnoreCancelPush) { "00"); // push ID QuicStreamFrame data3(receive_control_stream_id, /* fin = */ false, offset, cancel_push_frame); - EXPECT_CALL(debug_visitor, OnCancelPushFrameReceived(_)); + EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_FRAME_ERROR, + "CANCEL_PUSH frame received.", _)) + .WillOnce( + Invoke(connection_, &MockQuicConnection::ReallyCloseConnection)); + EXPECT_CALL(*connection_, + SendConnectionClosePacket(QUIC_HTTP_FRAME_ERROR, _, + "CANCEL_PUSH frame received.")); session_.OnStreamFrame(data3); } @@ -3469,36 +3444,32 @@ TEST_P(QuicSpdySessionTestClient, AlpsTwoSettingsFrame) { EXPECT_EQ("multiple SETTINGS frames", error.value()); } -TEST_P(QuicSpdySessionTestClient, GetNextDatagramFlowId) { +void QuicSpdySessionTestBase::TestHttpDatagramSetting( + HttpDatagramSupport local_support, HttpDatagramSupport remote_support, + HttpDatagramSupport expected_support, bool expected_datagram_supported) { if (!version().UsesHttp3()) { return; } - EXPECT_EQ(session_.GetNextDatagramFlowId(), 0u); - EXPECT_EQ(session_.GetNextDatagramFlowId(), 2u); - EXPECT_EQ(session_.GetNextDatagramFlowId(), 4u); - EXPECT_EQ(session_.GetNextDatagramFlowId(), 6u); -} - -TEST_P(QuicSpdySessionTestServer, GetNextDatagramFlowId) { - if (!version().UsesHttp3()) { - return; - } - EXPECT_EQ(session_.GetNextDatagramFlowId(), 1u); - EXPECT_EQ(session_.GetNextDatagramFlowId(), 3u); - EXPECT_EQ(session_.GetNextDatagramFlowId(), 5u); - EXPECT_EQ(session_.GetNextDatagramFlowId(), 7u); -} - -TEST_P(QuicSpdySessionTestClient, H3DatagramSetting) { - if (!version().UsesHttp3()) { - return; - } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(local_support); // HTTP/3 datagrams aren't supported before SETTINGS are received. - EXPECT_FALSE(session_.h3_datagram_supported()); + EXPECT_FALSE(session_.SupportsH3Datagram()); + EXPECT_EQ(session_.http_datagram_support(), HttpDatagramSupport::kNone); // Receive SETTINGS. SettingsFrame settings; - settings.values[SETTINGS_H3_DATAGRAM] = 1; + switch (remote_support) { + case HttpDatagramSupport::kNone: + break; + case HttpDatagramSupport::kDraft00: + settings.values[SETTINGS_H3_DATAGRAM_DRAFT00] = 1; + break; + case HttpDatagramSupport::kDraft04: + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + break; + case HttpDatagramSupport::kDraft00And04: + settings.values[SETTINGS_H3_DATAGRAM_DRAFT00] = 1; + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + break; + } std::string data = std::string(1, kControlStream) + EncodeSettings(settings); QuicStreamId stream_id = GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); @@ -3508,56 +3479,88 @@ TEST_P(QuicSpdySessionTestClient, H3DatagramSetting) { EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(stream_id)); EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(settings)); session_.OnStreamFrame(frame); - // HTTP/3 datagrams are now supported. - EXPECT_TRUE(session_.h3_datagram_supported()); + EXPECT_EQ(session_.http_datagram_support(), expected_support); + EXPECT_EQ(session_.SupportsH3Datagram(), expected_datagram_supported); } -TEST_P(QuicSpdySessionTestClient, H3DatagramRegistration) { - if (!version().UsesHttp3()) { - return; - } - CompleteHandshake(); - SetQuicReloadableFlag(quic_h3_datagram, true); - QuicSpdySessionPeer::SetH3DatagramSupported(&session_, true); - SavingHttp3DatagramVisitor h3_datagram_visitor; - QuicDatagramFlowId flow_id = session_.GetNextDatagramFlowId(); - ASSERT_EQ(QuicDataWriter::GetVarInt62Len(flow_id), 1); - uint8_t datagram[256]; - datagram[0] = flow_id; - for (size_t i = 1; i < ABSL_ARRAYSIZE(datagram); i++) { - datagram[i] = i; - } - session_.RegisterHttp3FlowId(flow_id, &h3_datagram_visitor); - session_.OnMessageReceived(absl::string_view( - reinterpret_cast(datagram), sizeof(datagram))); - EXPECT_THAT( - h3_datagram_visitor.received_h3_datagrams(), - ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ - flow_id, std::string(reinterpret_cast(datagram + 1), - sizeof(datagram) - 1)})); - session_.UnregisterHttp3FlowId(flow_id); -} - -TEST_P(QuicSpdySessionTestClient, SendHttp3Datagram) { - if (!version().UsesHttp3()) { - return; - } - CompleteHandshake(); - SetQuicReloadableFlag(quic_h3_datagram, true); - QuicSpdySessionPeer::SetH3DatagramSupported(&session_, true); - QuicDatagramFlowId flow_id = session_.GetNextDatagramFlowId(); - std::string h3_datagram_payload = {1, 2, 3, 4, 5, 6}; - EXPECT_CALL(*connection_, SendMessage(1, _, false)) - .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); - EXPECT_EQ(session_.SendHttp3Datagram(flow_id, h3_datagram_payload), - MESSAGE_STATUS_SUCCESS); +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal00Remote00) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft00, + /*remote_support=*/HttpDatagramSupport::kDraft00, + /*expected_support=*/HttpDatagramSupport::kDraft00, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal00Remote04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft00, + /*remote_support=*/HttpDatagramSupport::kDraft04, + /*expected_support=*/HttpDatagramSupport::kNone, + /*expected_datagram_supported=*/false); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal00Remote00And04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft00, + /*remote_support=*/HttpDatagramSupport::kDraft00And04, + /*expected_support=*/HttpDatagramSupport::kDraft00, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04Remote00) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft04, + /*remote_support=*/HttpDatagramSupport::kDraft00, + /*expected_support=*/HttpDatagramSupport::kNone, + /*expected_datagram_supported=*/false); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04Remote04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft04, + /*remote_support=*/HttpDatagramSupport::kDraft04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal04Remote00And04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft04, + /*remote_support=*/HttpDatagramSupport::kDraft00And04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal00And04Remote00) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft00And04, + /*remote_support=*/HttpDatagramSupport::kDraft00, + /*expected_support=*/HttpDatagramSupport::kDraft00, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, HttpDatagramSettingLocal00And04Remote04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft00And04, + /*remote_support=*/HttpDatagramSupport::kDraft04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); +} + +TEST_P(QuicSpdySessionTestClient, + HttpDatagramSettingLocal00And04Remote00And04) { + TestHttpDatagramSetting( + /*local_support=*/HttpDatagramSupport::kDraft00And04, + /*remote_support=*/HttpDatagramSupport::kDraft00And04, + /*expected_support=*/HttpDatagramSupport::kDraft04, + /*expected_datagram_supported=*/true); } TEST_P(QuicSpdySessionTestClient, WebTransportSetting) { if (!version().UsesHttp3()) { return; } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_.set_supports_webtransport(true); EXPECT_FALSE(session_.SupportsWebTransport()); @@ -3569,17 +3572,10 @@ TEST_P(QuicSpdySessionTestClient, WebTransportSetting) { session_.set_debug_visitor(&debug_visitor); CompleteHandshake(); - SettingsFrame server_settings; - server_settings.values[SETTINGS_H3_DATAGRAM] = 1; - server_settings.values[SETTINGS_WEBTRANS_DRAFT00] = 1; - std::string data = - std::string(1, kControlStream) + EncodeSettings(server_settings); - QuicStreamId stream_id = - GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); - QuicStreamFrame frame(stream_id, /*fin=*/false, /*offset=*/0, data); - EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(stream_id)); - EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(server_settings)); - session_.OnStreamFrame(frame); + EXPECT_CALL(debug_visitor, OnPeerControlStreamCreated(_)); + EXPECT_CALL(debug_visitor, OnSettingsFrameReceived(_)); + ReceiveWebTransportSettings(); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); EXPECT_TRUE(session_.SupportsWebTransport()); } @@ -3587,7 +3583,7 @@ TEST_P(QuicSpdySessionTestClient, WebTransportSettingSetToZero) { if (!version().UsesHttp3()) { return; } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_.set_supports_webtransport(true); EXPECT_FALSE(session_.SupportsWebTransport()); @@ -3600,7 +3596,7 @@ TEST_P(QuicSpdySessionTestClient, WebTransportSettingSetToZero) { CompleteHandshake(); SettingsFrame server_settings; - server_settings.values[SETTINGS_H3_DATAGRAM] = 1; + server_settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; server_settings.values[SETTINGS_WEBTRANS_DRAFT00] = 0; std::string data = std::string(1, kControlStream) + EncodeSettings(server_settings); @@ -3617,7 +3613,7 @@ TEST_P(QuicSpdySessionTestServer, WebTransportSetting) { if (!version().UsesHttp3()) { return; } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_.set_supports_webtransport(true); EXPECT_FALSE(session_.SupportsWebTransport()); @@ -3634,7 +3630,7 @@ TEST_P(QuicSpdySessionTestServer, BufferingIncomingStreams) { if (!version().UsesHttp3()) { return; } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_.set_supports_webtransport(true); CompleteHandshake(); @@ -3667,7 +3663,7 @@ TEST_P(QuicSpdySessionTestServer, BufferingIncomingStreamsLimit) { if (!version().UsesHttp3()) { return; } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_.set_supports_webtransport(true); CompleteHandshake(); @@ -3708,7 +3704,7 @@ TEST_P(QuicSpdySessionTestServer, ResetOutgoingWebTransportStreams) { if (!version().UsesHttp3()) { return; } - SetQuicReloadableFlag(quic_h3_datagram, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_.set_supports_webtransport(true); CompleteHandshake(); @@ -3739,6 +3735,91 @@ TEST_P(QuicSpdySessionTestServer, ResetOutgoingWebTransportStreams) { EXPECT_EQ(web_transport->NumberOfAssociatedStreams(), 0u); } +TEST_P(QuicSpdySessionTestClient, WebTransportWithoutExtendedConnect) { + if (!version().UsesHttp3()) { + return; + } + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + session_.set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + session_.set_supports_webtransport(true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + CompleteHandshake(); + + SettingsFrame settings; + settings.values[SETTINGS_H3_DATAGRAM_DRAFT04] = 1; + settings.values[SETTINGS_WEBTRANS_DRAFT00] = 1; + // No SETTINGS_ENABLE_CONNECT_PROTOCOL here. + std::string data = std::string(1, kControlStream) + EncodeSettings(settings); + QuicStreamId control_stream_id = + session_.perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3) + : GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame frame(control_stream_id, /*fin=*/false, /*offset=*/0, data); + session_.OnStreamFrame(frame); + + EXPECT_TRUE(session_.SupportsWebTransport()); +} + +class QuicSpdySessionTestServerNoExtendedConnect + : public QuicSpdySessionTestBase { + public: + QuicSpdySessionTestServerNoExtendedConnect() + : QuicSpdySessionTestBase(Perspective::IS_SERVER, false) {} +}; + +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdySessionTestServerNoExtendedConnect, + ::testing::ValuesIn(AllSupportedVersions()), + ::testing::PrintToStringParamName()); + +// Tests that receiving SETTINGS_ENABLE_CONNECT_PROTOCOL = 1 doesn't enable +// server session to support extended CONNECT. +TEST_P(QuicSpdySessionTestServerNoExtendedConnect, + WebTransportSettingNoEffect) { + if (!version().UsesHttp3()) { + return; + } + + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); + + CompleteHandshake(); + + ReceiveWebTransportSettings(); + EXPECT_FALSE(session_.allow_extended_connect()); + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); +} + +TEST_P(QuicSpdySessionTestServerNoExtendedConnect, BadExtendedConnectSetting) { + if (!version().UsesHttp3()) { + return; + } + SetQuicReloadableFlag(quic_verify_request_headers_2, true); + SetQuicReloadableFlag(quic_act_upon_invalid_header, true); + + EXPECT_FALSE(session_.SupportsWebTransport()); + EXPECT_TRUE(session_.ShouldProcessIncomingRequests()); + + CompleteHandshake(); + + // ENABLE_CONNECT_PROTOCOL setting value has to be 1 or 0; + SettingsFrame settings; + settings.values[SETTINGS_ENABLE_CONNECT_PROTOCOL] = 2; + std::string data = std::string(1, kControlStream) + EncodeSettings(settings); + QuicStreamId control_stream_id = + session_.perspective() == Perspective::IS_SERVER + ? GetNthClientInitiatedUnidirectionalStreamId(transport_version(), 3) + : GetNthServerInitiatedUnidirectionalStreamId(transport_version(), 3); + QuicStreamFrame frame(control_stream_id, /*fin=*/false, /*offset=*/0, data); + EXPECT_CALL(*connection_, + CloseConnection(QUIC_HTTP_INVALID_SETTING_VALUE, _, _)); + EXPECT_QUIC_PEER_BUG( + session_.OnStreamFrame(frame), + "Received SETTINGS_ENABLE_CONNECT_PROTOCOL with invalid value"); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_stream.cc b/gquiche/quic/core/http/quic_spdy_stream.cc index 6da40675..e27e5192 100644 --- a/gquiche/quic/core/http/quic_spdy_stream.cc +++ b/gquiche/quic/core/http/quic_spdy_stream.cc @@ -13,24 +13,27 @@ #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "gquiche/http2/http2_constants.h" +#include "gquiche/quic/core/http/capsule.h" #include "gquiche/quic/core/http/http_constants.h" #include "gquiche/quic/core/http/http_decoder.h" +#include "gquiche/quic/core/http/http_frames.h" #include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/core/http/web_transport_http3.h" #include "gquiche/quic/core/qpack/qpack_decoder.h" #include "gquiche/quic/core/qpack/qpack_encoder.h" -#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/core/quic_write_blocked_list.h" +#include "gquiche/quic/core/web_transport_interface.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_mem_slice_storage.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" #include "gquiche/spdy/core/spdy_protocol.h" using spdy::SpdyHeaderBlock; @@ -50,11 +53,6 @@ class QuicSpdyStream::HttpDecoderVisitor : public HttpDecoder::Visitor { stream_->OnUnrecoverableError(decoder->error(), decoder->error_detail()); } - bool OnCancelPushFrame(const CancelPushFrame& /*frame*/) override { - CloseConnectionOnWrongFrame("Cancel Push"); - return false; - } - bool OnMaxPushIdFrame(const MaxPushIdFrame& /*frame*/) override { CloseConnectionOnWrongFrame("Max Push Id"); return false; @@ -113,42 +111,6 @@ class QuicSpdyStream::HttpDecoderVisitor : public HttpDecoder::Visitor { return stream_->OnHeadersFrameEnd(); } - bool OnPushPromiseFrameStart(QuicByteCount header_length) override { - if (!VersionUsesHttp3(stream_->transport_version())) { - CloseConnectionOnWrongFrame("Push Promise"); - return false; - } - return stream_->OnPushPromiseFrameStart(header_length); - } - - bool OnPushPromiseFramePushId(PushId push_id, - QuicByteCount push_id_length, - QuicByteCount header_block_length) override { - if (!VersionUsesHttp3(stream_->transport_version())) { - CloseConnectionOnWrongFrame("Push Promise"); - return false; - } - return stream_->OnPushPromiseFramePushId(push_id, push_id_length, - header_block_length); - } - - bool OnPushPromiseFramePayload(absl::string_view payload) override { - QUICHE_DCHECK(!payload.empty()); - if (!VersionUsesHttp3(stream_->transport_version())) { - CloseConnectionOnWrongFrame("Push Promise"); - return false; - } - return stream_->OnPushPromiseFramePayload(payload); - } - - bool OnPushPromiseFrameEnd() override { - if (!VersionUsesHttp3(stream_->transport_version())) { - CloseConnectionOnWrongFrame("Push Promise"); - return false; - } - return stream_->OnPushPromiseFrameEnd(); - } - bool OnPriorityUpdateFrameStart(QuicByteCount /*header_length*/) override { CloseConnectionOnWrongFrame("Priority update"); return false; @@ -170,13 +132,11 @@ class QuicSpdyStream::HttpDecoderVisitor : public HttpDecoder::Visitor { } void OnWebTransportStreamFrameType( - QuicByteCount header_length, - WebTransportSessionId session_id) override { + QuicByteCount header_length, WebTransportSessionId session_id) override { stream_->OnWebTransportStreamFrameType(header_length, session_id); } - bool OnUnknownFrameStart(uint64_t frame_type, - QuicByteCount header_length, + bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length, QuicByteCount payload_length) override { return stream_->OnUnknownFrameStart(frame_type, header_length, payload_length); @@ -213,8 +173,7 @@ HttpDecoder::Options HttpDecoderOptionsForBidiStream( } } // namespace -QuicSpdyStream::QuicSpdyStream(QuicStreamId id, - QuicSpdySession* spdy_session, +QuicSpdyStream::QuicSpdyStream(QuicStreamId id, QuicSpdySession* spdy_session, StreamType type) : QuicStream(id, spdy_session, /*is_static=*/false, type), spdy_session_(spdy_session), @@ -232,7 +191,11 @@ QuicSpdyStream::QuicSpdyStream(QuicStreamId id, sequencer_offset_(0), is_decoder_processing_input_(false), ack_listener_(nullptr), - last_sent_urgency_(kDefaultUrgency) { + last_sent_urgency_(kDefaultUrgency), + datagram_next_available_context_id_(spdy_session->perspective() == + Perspective::IS_SERVER + ? kFirstDatagramContextIdServer + : kFirstDatagramContextIdClient) { QUICHE_DCHECK_EQ(session()->connection(), spdy_session->connection()); QUICHE_DCHECK_EQ(transport_version(), spdy_session->transport_version()); QUICHE_DCHECK(!QuicUtils::IsCryptoStreamId(transport_version(), id)); @@ -251,9 +214,8 @@ QuicSpdyStream::QuicSpdyStream(QuicStreamId id, } QuicSpdyStream::QuicSpdyStream(PendingStream* pending, - QuicSpdySession* spdy_session, - StreamType type) - : QuicStream(pending, spdy_session, type, /*is_static=*/false), + QuicSpdySession* spdy_session) + : QuicStream(pending, spdy_session, /*is_static=*/false), spdy_session_(spdy_session), on_body_available_called_because_sequencer_is_closed_(false), visitor_(nullptr), @@ -287,9 +249,15 @@ QuicSpdyStream::QuicSpdyStream(PendingStream* pending, QuicSpdyStream::~QuicSpdyStream() {} +bool QuicSpdyStream::ShouldUseDatagramContexts() const { + return spdy_session_->SupportsH3Datagram() && + spdy_session_->http_datagram_support() != + HttpDatagramSupport::kDraft00 && + use_datagram_contexts_; +} + size_t QuicSpdyStream::WriteHeaders( - SpdyHeaderBlock header_block, - bool fin, + SpdyHeaderBlock header_block, bool fin, QuicReferenceCountedPointer ack_listener) { if (!AssertNotWebTransportDataStream("writing headers")) { return 0; @@ -315,6 +283,18 @@ size_t QuicSpdyStream::WriteHeaders( MaybeProcessSentWebTransportHeaders(header_block); + if (ShouldUseDatagramContexts()) { + // RegisterHttp3DatagramRegistrationVisitor caller wishes to use contexts, + // inform the peer. + header_block["sec-use-datagram-contexts"] = "?1"; + } + + if (web_transport_ != nullptr && + spdy_session_->http_datagram_support() != HttpDatagramSupport::kDraft00 && + spdy_session_->perspective() == Perspective::IS_SERVER) { + header_block["sec-webtransport-http3-draft"] = "draft02"; + } + size_t bytes_written = WriteHeadersImpl(std::move(header_block), fin, std::move(ack_listener)); if (!VersionUsesHttp3(transport_version()) && fin) { @@ -325,6 +305,16 @@ size_t QuicSpdyStream::WriteHeaders( SetFinSent(); CloseWriteSide(); } + + if (web_transport_ != nullptr && + session()->perspective() == Perspective::IS_CLIENT) { + // This will send a capsule and therefore needs to happen after headers have + // been sent. + RegisterHttp3DatagramContextId( + web_transport_->context_id(), DatagramFormatType::WEBTRANSPORT, + /*format_additional_data=*/absl::string_view(), web_transport_.get()); + } + return bytes_written; } @@ -342,18 +332,9 @@ void QuicSpdyStream::WriteOrBufferBody(absl::string_view data, bool fin) { spdy_session_->debug_visitor()->OnDataFrameSent(id(), data.length()); } - // Write frame header. - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(data.length(), &buffer); - unacked_frame_headers_offsets_.Add( - send_buffer().stream_offset(), - send_buffer().stream_offset() + header_length); - QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() - << " is writing DATA frame header of length " - << header_length; - WriteOrBufferData(absl::string_view(buffer.get(), header_length), false, - nullptr); + const bool success = + WriteDataFrameHeader(data.length(), /*force_write=*/true); + QUICHE_DCHECK(success); // Write body. QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() @@ -403,8 +384,7 @@ size_t QuicSpdyStream::WriteTrailers( return bytes_written; } -QuicConsumedData QuicSpdyStream::WritevBody(const struct iovec* iov, - int count, +QuicConsumedData QuicSpdyStream::WritevBody(const struct iovec* iov, int count, bool fin) { QuicMemSliceStorage storage( iov, count, @@ -413,44 +393,53 @@ QuicConsumedData QuicSpdyStream::WritevBody(const struct iovec* iov, return WriteBodySlices(storage.ToSpan(), fin); } -QuicConsumedData QuicSpdyStream::WriteBodySlices(QuicMemSliceSpan slices, - bool fin) { - if (!VersionUsesHttp3(transport_version()) || slices.empty()) { - return WriteMemSlices(slices, fin); - } - - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(slices.total_length(), &buffer); - if (!CanWriteNewDataAfterData(header_length)) { - return {0, false}; +bool QuicSpdyStream::WriteDataFrameHeader(QuicByteCount data_length, + bool force_write) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + QUICHE_DCHECK_GT(data_length, 0u); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + data_length, + spdy_session_->connection()->helper()->GetStreamSendBufferAllocator()); + const bool can_write = CanWriteNewDataAfterData(header.size()); + if (!can_write && !force_write) { + return false; } if (spdy_session_->debug_visitor()) { - spdy_session_->debug_visitor()->OnDataFrameSent(id(), - slices.total_length()); + spdy_session_->debug_visitor()->OnDataFrameSent(id(), data_length); } - QuicConnection::ScopedPacketFlusher flusher(spdy_session_->connection()); - - // Write frame header. - struct iovec header_iov = {static_cast(buffer.get()), header_length}; - QuicMemSliceStorage storage( - &header_iov, 1, - spdy_session_->connection()->helper()->GetStreamSendBufferAllocator(), - GetQuicFlag(FLAGS_quic_send_buffer_max_data_slice_size)); unacked_frame_headers_offsets_.Add( send_buffer().stream_offset(), - send_buffer().stream_offset() + header_length); + send_buffer().stream_offset() + header.size()); QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() << " is writing DATA frame header of length " - << header_length; - WriteMemSlices(storage.ToSpan(), false); + << header.size(); + if (can_write) { + // Save one copy and allocation if send buffer can accomodate the header. + QuicMemSlice header_slice(std::move(header)); + WriteMemSlices(absl::MakeSpan(&header_slice, 1), false); + } else { + QUICHE_DCHECK(force_write); + WriteOrBufferData(header.AsStringView(), false, nullptr); + } + return true; +} + +QuicConsumedData QuicSpdyStream::WriteBodySlices( + absl::Span slices, bool fin) { + if (!VersionUsesHttp3(transport_version()) || slices.empty()) { + return WriteMemSlices(slices, fin); + } + + QuicConnection::ScopedPacketFlusher flusher(spdy_session_->connection()); + const QuicByteCount data_size = MemSliceSpanTotalSize(slices); + if (!WriteDataFrameHeader(data_size, /*force_write=*/false)) { + return {0, false}; + } - // Write body. QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() - << " is writing DATA frame payload of length " - << slices.total_length(); + << " is writing DATA frame payload of length " << data_size; return WriteMemSlices(slices, fin); } @@ -497,9 +486,7 @@ bool QuicSpdyStream::HasBytesToRead() const { return body_manager_.HasBytesToRead(); } -void QuicSpdyStream::MarkTrailersConsumed() { - trailers_consumed_ = true; -} +void QuicSpdyStream::MarkTrailersConsumed() { trailers_consumed_ = true; } uint64_t QuicSpdyStream::total_body_bytes_read() const { if (VersionUsesHttp3(transport_version())) { @@ -521,14 +508,14 @@ void QuicSpdyStream::ConsumeHeaderList() { } if (body_manager_.HasBytesToRead()) { - OnBodyAvailable(); + HandleBodyAvailable(); return; } if (sequencer()->IsClosed() && !on_body_available_called_because_sequencer_is_closed_) { on_body_available_called_because_sequencer_is_closed_ = true; - OnBodyAvailable(); + HandleBodyAvailable(); } } @@ -539,8 +526,7 @@ void QuicSpdyStream::OnStreamHeadersPriority( SetPriority(precedence); } -void QuicSpdyStream::OnStreamHeaderList(bool fin, - size_t frame_len, +void QuicSpdyStream::OnStreamHeaderList(bool fin, size_t frame_len, const QuicHeaderList& header_list) { if (!spdy_session()->user_agent_id().has_value()) { std::string uaid; @@ -593,10 +579,6 @@ void QuicSpdyStream::OnHeadersDecoded(QuicHeaderList headers, OnStreamHeaderList(/* fin = */ false, headers_payload_length_, headers); } else { - if (debug_visitor) { - debug_visitor->OnPushPromiseDecoded(id(), promised_stream_id, headers); - } - spdy_session_->OnHeaderList(headers); } @@ -607,14 +589,14 @@ void QuicSpdyStream::OnHeadersDecoded(QuicHeaderList headers, } } -void QuicSpdyStream::OnHeaderDecodingError(absl::string_view error_message) { +void QuicSpdyStream::OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) { qpack_decoded_headers_accumulator_.reset(); std::string connection_close_error_message = absl::StrCat( "Error decoding ", headers_decompressed_ ? "trailers" : "headers", " on stream ", id(), ": ", error_message); - OnUnrecoverableError(QUIC_QPACK_DECOMPRESSION_FAILED, - connection_close_error_message); + OnUnrecoverableError(error_code, connection_close_error_message); } void QuicSpdyStream::MaybeSendPriorityUpdateFrame() { @@ -637,19 +619,45 @@ void QuicSpdyStream::MaybeSendPriorityUpdateFrame() { spdy_session_->WriteHttp3PriorityUpdate(priority_update); } -void QuicSpdyStream::OnHeadersTooLarge() { - Reset(QUIC_HEADERS_TOO_LARGE); -} +void QuicSpdyStream::OnHeadersTooLarge() { Reset(QUIC_HEADERS_TOO_LARGE); } void QuicSpdyStream::OnInitialHeadersComplete( - bool fin, - size_t /*frame_len*/, - const QuicHeaderList& header_list) { + bool fin, size_t /*frame_len*/, const QuicHeaderList& header_list) { // TODO(b/134706391): remove |fin| argument. headers_decompressed_ = true; header_list_ = header_list; - - MaybeProcessReceivedWebTransportHeaders(); + bool header_too_large = VersionUsesHttp3(transport_version()) + ? header_list_size_limit_exceeded_ + : header_list.empty(); + // Validate request headers if it did not exceed size limit. If it did, + // OnHeadersTooLarge() should have already handled it previously. + if (!header_too_large && !AreHeadersValid(header_list)) { + QUIC_CODE_COUNT_N(quic_validate_request_header, 1, 2); + if (GetQuicReloadableFlag(quic_act_upon_invalid_header)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_act_upon_invalid_header); + OnInvalidHeaders(); + return; + } + } + QUIC_CODE_COUNT_N(quic_validate_request_header, 2, 2); + + if (!GetQuicReloadableFlag(quic_verify_request_headers_2) || + !header_too_large) { + MaybeProcessReceivedWebTransportHeaders(); + if (ShouldUseDatagramContexts()) { + bool peer_wishes_to_use_datagram_contexts = false; + for (const auto& header : header_list_) { + if (header.first == "sec-use-datagram-contexts" && + header.second == "?1") { + peer_wishes_to_use_datagram_contexts = true; + break; + } + } + if (!peer_wishes_to_use_datagram_contexts) { + use_datagram_contexts_ = false; + } + } + } if (VersionUsesHttp3(transport_version())) { if (fin) { @@ -670,8 +678,7 @@ void QuicSpdyStream::OnInitialHeadersComplete( } void QuicSpdyStream::OnPromiseHeaderList( - QuicStreamId /* promised_id */, - size_t /* frame_len */, + QuicStreamId /* promised_id */, size_t /* frame_len */, const QuicHeaderList& /*header_list */) { // To be overridden in QuicSpdyClientStream. Not supported on // server side. @@ -680,9 +687,7 @@ void QuicSpdyStream::OnPromiseHeaderList( } void QuicSpdyStream::OnTrailingHeadersComplete( - bool fin, - size_t /*frame_len*/, - const QuicHeaderList& header_list) { + bool fin, size_t /*frame_len*/, const QuicHeaderList& header_list) { // TODO(b/134706391): remove |fin| argument. QUICHE_DCHECK(!trailers_decompressed_); if (!VersionUsesHttp3(transport_version()) && fin_received()) { @@ -730,6 +735,12 @@ void QuicSpdyStream::OnPriorityFrame( void QuicSpdyStream::OnStreamReset(const QuicRstStreamFrame& frame) { if (web_transport_data_ != nullptr) { + WebTransportStreamVisitor* webtransport_visitor = + web_transport_data_->adapter.visitor(); + if (webtransport_visitor != nullptr) { + webtransport_visitor->OnResetStreamReceived( + Http3ErrorToWebTransportOrDefault(frame.ietf_error_code)); + } QuicStream::OnStreamReset(frame); return; } @@ -768,11 +779,11 @@ void QuicSpdyStream::OnStreamReset(const QuicRstStreamFrame& frame) { << "Received QUIC_STREAM_NO_ERROR, not discarding response"; set_rst_received(true); MaybeIncreaseHighestReceivedOffset(frame.byte_offset); - set_stream_error(frame.error_code); + set_stream_error(frame.error()); CloseWriteSide(); } -void QuicSpdyStream::Reset(QuicRstStreamErrorCode error) { +void QuicSpdyStream::ResetWithError(QuicResetStreamError error) { if (VersionUsesHttp3(transport_version()) && !fin_received() && spdy_session_->qpack_decoder() && web_transport_data_ == nullptr) { QUIC_CODE_COUNT_N(quic_abort_qpack_on_stream_reset, 2, 2); @@ -783,7 +794,30 @@ void QuicSpdyStream::Reset(QuicRstStreamErrorCode error) { } } - QuicStream::Reset(error); + QuicStream::ResetWithError(error); +} + +bool QuicSpdyStream::OnStopSending(QuicResetStreamError error) { + if (web_transport_data_ != nullptr) { + WebTransportStreamVisitor* visitor = web_transport_data_->adapter.visitor(); + if (visitor != nullptr) { + visitor->OnStopSendingReceived( + Http3ErrorToWebTransportOrDefault(error.ietf_application_code())); + } + } + + return QuicStream::OnStopSending(error); +} + +void QuicSpdyStream::OnWriteSideInDataRecvdState() { + if (web_transport_data_ != nullptr) { + WebTransportStreamVisitor* visitor = web_transport_data_->adapter.visitor(); + if (visitor != nullptr) { + visitor->OnWriteSideInDataRecvdState(); + } + } + + QuicStream::OnWriteSideInDataRecvdState(); } void QuicSpdyStream::OnDataAvailable() { @@ -793,7 +827,7 @@ void QuicSpdyStream::OnDataAvailable() { } if (!VersionUsesHttp3(transport_version())) { - OnBodyAvailable(); + HandleBodyAvailable(); return; } @@ -838,20 +872,20 @@ void QuicSpdyStream::OnDataAvailable() { } } - // Do not call OnBodyAvailable() until headers are consumed. + // Do not call HandleBodyAvailable() until headers are consumed. if (!FinishedReadingHeaders()) { return; } if (body_manager_.HasBytesToRead()) { - OnBodyAvailable(); + HandleBodyAvailable(); return; } if (sequencer()->IsClosed() && !on_body_available_called_because_sequencer_is_closed_) { on_body_available_called_because_sequencer_is_closed_ = true; - OnBodyAvailable(); + HandleBodyAvailable(); } } @@ -868,8 +902,12 @@ void QuicSpdyStream::OnClose() { visitor->OnClose(this); } + if (datagram_flow_id_.has_value()) { + spdy_session_->UnregisterHttp3DatagramFlowId(datagram_flow_id_.value()); + } + if (web_transport_ != nullptr) { - web_transport_->CloseAllAssociatedStreams(); + web_transport_->OnConnectStreamClosing(); } if (web_transport_data_ != nullptr) { WebTransportHttp3* web_transport = @@ -1009,8 +1047,7 @@ void QuicSpdyStream::OnStreamFrameRetransmitted(QuicStreamOffset offset, } QuicByteCount QuicSpdyStream::GetNumFrameHeadersInInterval( - QuicStreamOffset offset, - QuicByteCount data_length) const { + QuicStreamOffset offset, QuicByteCount data_length) const { QuicByteCount header_acked_length = 0; QuicIntervalSet newly_acked(offset, offset + data_length); newly_acked.Intersection(unacked_frame_headers_offsets_); @@ -1080,53 +1117,8 @@ bool QuicSpdyStream::OnHeadersFrameEnd() { return !sequencer()->IsClosed() && !reading_stopped(); } -bool QuicSpdyStream::OnPushPromiseFrameStart(QuicByteCount header_length) { - QUICHE_DCHECK(VersionUsesHttp3(transport_version())); - QUICHE_DCHECK(!qpack_decoded_headers_accumulator_); - - sequencer()->MarkConsumed(body_manager_.OnNonBody(header_length)); - - return true; -} - -bool QuicSpdyStream::OnPushPromiseFramePushId( - PushId push_id, - QuicByteCount push_id_length, - QuicByteCount header_block_length) { - QUICHE_DCHECK(VersionUsesHttp3(transport_version())); - QUICHE_DCHECK(!qpack_decoded_headers_accumulator_); - - if (spdy_session_->debug_visitor()) { - spdy_session_->debug_visitor()->OnPushPromiseFrameReceived( - id(), push_id, header_block_length); - } - - // TODO(b/151749109): Check max push id and handle errors. - spdy_session_->OnPushPromise(id(), push_id); - sequencer()->MarkConsumed(body_manager_.OnNonBody(push_id_length)); - - qpack_decoded_headers_accumulator_ = - std::make_unique( - id(), spdy_session_->qpack_decoder(), this, - spdy_session_->max_inbound_header_list_size()); - - return true; -} - -bool QuicSpdyStream::OnPushPromiseFramePayload(absl::string_view payload) { - spdy_session_->OnCompressedFrameSize(payload.length()); - return OnHeadersFramePayload(payload); -} - -bool QuicSpdyStream::OnPushPromiseFrameEnd() { - QUICHE_DCHECK(VersionUsesHttp3(transport_version())); - - return OnHeadersFrameEnd(); -} - void QuicSpdyStream::OnWebTransportStreamFrameType( - QuicByteCount header_length, - WebTransportSessionId session_id) { + QuicByteCount header_length, WebTransportSessionId session_id) { QUIC_DVLOG(1) << ENDPOINT << " Received WEBTRANSPORT_STREAM on stream " << id() << " for session " << session_id; sequencer()->MarkConsumed(header_length); @@ -1178,13 +1170,10 @@ bool QuicSpdyStream::OnUnknownFramePayload(absl::string_view payload) { return true; } -bool QuicSpdyStream::OnUnknownFrameEnd() { - return true; -} +bool QuicSpdyStream::OnUnknownFrameEnd() { return true; } size_t QuicSpdyStream::WriteHeadersImpl( - spdy::SpdyHeaderBlock header_block, - bool fin, + spdy::SpdyHeaderBlock header_block, bool fin, QuicReferenceCountedPointer ack_listener) { if (!VersionUsesHttp3(transport_version())) { return spdy_session_->WriteHeadersOnHeadersStream( @@ -1232,6 +1221,16 @@ size_t QuicSpdyStream::WriteHeadersImpl( return encoded_headers.size(); } +bool QuicSpdyStream::CanWriteNewBodyData(QuicByteCount write_size) const { + QUICHE_DCHECK_NE(0u, write_size); + if (!VersionUsesHttp3(transport_version())) { + return CanWriteNewData(); + } + + return CanWriteNewDataAfterData( + HttpEncoder::GetDataFrameHeaderLength(write_size)); +} + void QuicSpdyStream::MaybeProcessReceivedWebTransportHeaders() { if (!spdy_session_->SupportsWebTransport()) { return; @@ -1243,7 +1242,8 @@ void QuicSpdyStream::MaybeProcessReceivedWebTransportHeaders() { std::string method; std::string protocol; - absl::optional flow_id; + absl::optional flow_id; + bool version_indicated = false; for (const auto& header : header_list_) { const std::string& header_name = header.first; const std::string& header_value = header.second; @@ -1260,24 +1260,69 @@ void QuicSpdyStream::MaybeProcessReceivedWebTransportHeaders() { protocol = header_value; } if (header_name == "datagram-flow-id") { + if (spdy_session_->http_datagram_support() != + HttpDatagramSupport::kDraft00) { + QUIC_DLOG(ERROR) << ENDPOINT + << "Rejecting WebTransport due to unexpected " + "Datagram-Flow-Id header"; + return; + } if (flow_id.has_value() || header_value.empty()) { return; } - QuicDatagramFlowId flow_id_out; + QuicDatagramStreamId flow_id_out; if (!absl::SimpleAtoi(header_value, &flow_id_out)) { return; } flow_id = flow_id_out; } + if (header_name == "sec-webtransport-http3-draft02") { + if (header_value != "1") { + QUIC_DLOG(ERROR) << ENDPOINT + << "Rejecting WebTransport due to invalid value of " + "Sec-Webtransport-Http3-Draft02 header"; + return; + } + version_indicated = true; + } } - if (method != "CONNECT" || protocol != "webtransport" || - !flow_id.has_value()) { + if (method != "CONNECT" || protocol != "webtransport") { return; } - web_transport_ = - std::make_unique(spdy_session_, this, id(), *flow_id); + if (!version_indicated && + spdy_session_->http_datagram_support() != HttpDatagramSupport::kDraft00) { + QUIC_DLOG(ERROR) + << ENDPOINT + << "WebTransport request rejected due to missing version header."; + return; + } + + if (spdy_session_->http_datagram_support() == HttpDatagramSupport::kDraft00) { + if (!flow_id.has_value()) { + QUIC_DLOG(ERROR) + << ENDPOINT + << "Rejecting WebTransport due to missing Datagram-Flow-Id header"; + return; + } + RegisterHttp3DatagramFlowId(*flow_id); + } + + web_transport_ = std::make_unique( + spdy_session_, this, id(), + spdy_session_->ShouldNegotiateDatagramContexts()); + + if (spdy_session_->http_datagram_support() != HttpDatagramSupport::kDraft00) { + return; + } + // If we're in draft-ietf-masque-h3-datagram-00 mode, pretend we also received + // a REGISTER_DATAGRAM_NO_CONTEXT capsule. + // TODO(b/181256914) remove this when we remove support for + // draft-ietf-masque-h3-datagram-00 in favor of later drafts. + RegisterHttp3DatagramContextId( + /*context_id=*/absl::nullopt, DatagramFormatType::WEBTRANSPORT, + /*format_additional_data=*/absl::string_view(), web_transport_.get()); } void QuicSpdyStream::MaybeProcessSentWebTransportHeaders( @@ -1299,11 +1344,15 @@ void QuicSpdyStream::MaybeProcessSentWebTransportHeaders( return; } - QuicDatagramFlowId flow_id = spdy_session_->GetNextDatagramFlowId(); - headers["datagram-flow-id"] = absl::StrCat(flow_id); + if (spdy_session_->http_datagram_support() == HttpDatagramSupport::kDraft00) { + headers["datagram-flow-id"] = absl::StrCat(id()); + } else { + headers["sec-webtransport-http3-draft02"] = "1"; + } - web_transport_ = - std::make_unique(spdy_session_, this, id(), flow_id); + web_transport_ = std::make_unique( + spdy_session_, this, id(), + spdy_session_->ShouldNegotiateDatagramContexts()); } void QuicSpdyStream::OnCanWriteNewData() { @@ -1359,10 +1408,504 @@ void QuicSpdyStream::ConvertToWebTransportDataStream( } QuicSpdyStream::WebTransportDataStream::WebTransportDataStream( - QuicSpdyStream* stream, - WebTransportSessionId session_id) + QuicSpdyStream* stream, WebTransportSessionId session_id) : session_id(session_id), adapter(stream->spdy_session_, stream, stream->sequencer()) {} +void QuicSpdyStream::HandleReceivedDatagram( + absl::optional context_id, + absl::string_view payload) { + Http3DatagramVisitor* visitor; + if (context_id.has_value()) { + auto it = datagram_context_visitors_.find(context_id.value()); + if (it == datagram_context_visitors_.end()) { + QUIC_DLOG(ERROR) << ENDPOINT + << "Received datagram without any visitor for context " + << context_id.value(); + return; + } + visitor = it->second; + } else { + if (datagram_no_context_visitor_ == nullptr) { + QUIC_DLOG(ERROR) + << ENDPOINT << "Received datagram without any visitor for no context"; + return; + } + visitor = datagram_no_context_visitor_; + } + visitor->OnHttp3Datagram(id(), context_id, payload); +} + +bool QuicSpdyStream::OnCapsule(const Capsule& capsule) { + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() << " received capsule " + << capsule; + if (!headers_decompressed_) { + QUIC_PEER_BUG(capsule before headers) + << ENDPOINT << "Stream " << id() << " received capsule " << capsule + << " before headers"; + return false; + } + if (web_transport_ != nullptr && web_transport_->close_received()) { + QUIC_PEER_BUG(capsule after close) + << ENDPOINT << "Stream " << id() << " received capsule " << capsule + << " after CLOSE_WEBTRANSPORT_SESSION."; + return false; + } + switch (capsule.capsule_type()) { + case CapsuleType::LEGACY_DATAGRAM: { + HandleReceivedDatagram( + capsule.legacy_datagram_capsule().context_id, + capsule.legacy_datagram_capsule().http_datagram_payload); + } break; + case CapsuleType::DATAGRAM_WITH_CONTEXT: { + HandleReceivedDatagram( + capsule.datagram_with_context_capsule().context_id, + capsule.datagram_with_context_capsule().http_datagram_payload); + } break; + case CapsuleType::DATAGRAM_WITHOUT_CONTEXT: { + absl::optional context_id; + if (use_datagram_contexts_) { + // draft-ietf-masque-h3-datagram-05 encodes context ID 0 using + // DATAGRAM_WITHOUT_CONTEXT. + context_id = 0; + } + HandleReceivedDatagram( + context_id, + capsule.datagram_without_context_capsule().http_datagram_payload); + } break; + case CapsuleType::REGISTER_DATAGRAM_CONTEXT: { + if (datagram_registration_visitor_ == nullptr) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received capsule " << capsule + << " without any registration visitor"; + return false; + } + datagram_registration_visitor_->OnContextReceived( + id(), capsule.register_datagram_context_capsule().context_id, + capsule.register_datagram_context_capsule().format_type, + capsule.register_datagram_context_capsule().format_additional_data); + } break; + case CapsuleType::REGISTER_DATAGRAM_NO_CONTEXT: { + if (datagram_registration_visitor_ == nullptr) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received capsule " << capsule + << " without any registration visitor"; + return false; + } + absl::optional context_id; + if (use_datagram_contexts_) { + // draft-ietf-masque-h3-datagram-05 encodes context ID 0 using + // REGISTER_DATAGRAM_NO_CONTEXT. + context_id = 0; + } + datagram_registration_visitor_->OnContextReceived( + id(), context_id, + capsule.register_datagram_no_context_capsule().format_type, + capsule.register_datagram_no_context_capsule() + .format_additional_data); + } break; + case CapsuleType::CLOSE_DATAGRAM_CONTEXT: { + if (datagram_registration_visitor_ == nullptr) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received capsule " << capsule + << " without any registration visitor"; + return false; + } + datagram_registration_visitor_->OnContextClosed( + id(), capsule.close_datagram_context_capsule().context_id, + capsule.close_datagram_context_capsule().close_code, + capsule.close_datagram_context_capsule().close_details); + } break; + case CapsuleType::CLOSE_WEBTRANSPORT_SESSION: { + if (web_transport_ == nullptr) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received capsule " << capsule + << " for a non-WebTransport stream."; + return false; + } + web_transport_->OnCloseReceived( + capsule.close_web_transport_session_capsule().error_code, + capsule.close_web_transport_session_capsule().error_message); + } break; + } + return true; +} + +void QuicSpdyStream::OnCapsuleParseFailure(const std::string& error_message) { + QUIC_DLOG(ERROR) << ENDPOINT << "Capsule parse failure: " << error_message; + Reset(QUIC_BAD_APPLICATION_PAYLOAD); +} + +void QuicSpdyStream::WriteCapsule(const Capsule& capsule, bool fin) { + QUIC_DLOG(INFO) << ENDPOINT << "Stream " << id() << " sending capsule " + << capsule; + QuicBuffer serialized_capsule = SerializeCapsule( + capsule, + spdy_session_->connection()->helper()->GetStreamSendBufferAllocator()); + QUICHE_DCHECK_GT(serialized_capsule.size(), 0u); + WriteOrBufferBody(serialized_capsule.AsStringView(), /*fin=*/fin); +} + +void QuicSpdyStream::WriteGreaseCapsule() { + // GREASE capsulde IDs have a form of 41 * N + 23. + QuicRandom* random = spdy_session_->connection()->random_generator(); + uint64_t type = random->InsecureRandUint64() >> 4; + type = (type / 41) * 41 + 23; + QUICHE_DCHECK_EQ((type - 23) % 41, 0u); + + constexpr size_t kMaxLength = 64; + size_t length = random->InsecureRandUint64() % kMaxLength; + std::string bytes(length, '\0'); + random->InsecureRandBytes(&bytes[0], bytes.size()); + Capsule capsule = Capsule::Unknown(type, bytes); + WriteCapsule(capsule, /*fin=*/false); +} + +MessageStatus QuicSpdyStream::SendHttp3Datagram( + absl::optional context_id, + absl::string_view payload) { + QuicDatagramStreamId stream_id = + datagram_flow_id_.has_value() ? datagram_flow_id_.value() : id(); + return spdy_session_->SendHttp3Datagram(stream_id, context_id, payload); +} + +void QuicSpdyStream::RegisterHttp3DatagramRegistrationVisitor( + Http3DatagramRegistrationVisitor* visitor, bool use_datagram_contexts) { + if (visitor == nullptr) { + QUIC_BUG(null datagram registration visitor) + << ENDPOINT << "Null datagram registration visitor for" << id(); + return; + } + if (datagram_registration_visitor_ != nullptr) { + QUIC_BUG(double datagram registration visitor) + << ENDPOINT << "Double datagram registration visitor for" << id(); + return; + } + use_datagram_contexts_ = use_datagram_contexts; + QUIC_DLOG(INFO) << ENDPOINT << "Registering datagram stream ID " << id() + << " with" << (use_datagram_contexts_ ? "" : "out") + << " contexts"; + datagram_registration_visitor_ = visitor; + QUICHE_DCHECK(!capsule_parser_); + capsule_parser_.reset(new CapsuleParser(this)); +} + +void QuicSpdyStream::UnregisterHttp3DatagramRegistrationVisitor() { + QUIC_BUG_IF(h3 datagram unregister unknown stream ID, + datagram_registration_visitor_ == nullptr) + << ENDPOINT + << "Attempted to unregister unknown HTTP/3 datagram stream ID " << id(); + QUIC_DLOG(INFO) << ENDPOINT << "Unregistering datagram stream ID " << id(); + datagram_registration_visitor_ = nullptr; +} + +void QuicSpdyStream::MoveHttp3DatagramRegistration( + Http3DatagramRegistrationVisitor* visitor) { + QUIC_BUG_IF(h3 datagram move unknown stream ID, + datagram_registration_visitor_ == nullptr) + << ENDPOINT << "Attempted to move unknown HTTP/3 datagram stream ID " + << id(); + QUIC_DLOG(INFO) << ENDPOINT << "Moving datagram stream ID " << id(); + datagram_registration_visitor_ = visitor; +} + +void QuicSpdyStream::RegisterHttp3DatagramContextId( + absl::optional context_id, + DatagramFormatType format_type, absl::string_view format_additional_data, + Http3DatagramVisitor* visitor) { + if (visitor == nullptr) { + QUIC_BUG(null datagram visitor) + << ENDPOINT << "Null datagram visitor for stream ID " << id() + << " context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none"); + return; + } + if (datagram_registration_visitor_ == nullptr) { + QUIC_BUG(context registration without registration visitor) + << ENDPOINT << "Cannot register context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " without registration visitor for stream ID " << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Registering datagram context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) + : "none") + << " with stream ID " << id(); + + if (context_id.has_value()) { + if (datagram_no_context_visitor_ != nullptr) { + QUIC_BUG(h3 datagram context ID mix1) + << ENDPOINT + << "Attempted to mix registrations without and with context IDs " + "for stream ID " + << id(); + return; + } + auto insertion_result = + datagram_context_visitors_.insert({context_id.value(), visitor}); + if (!insertion_result.second) { + QUIC_BUG(h3 datagram double context registration) + << ENDPOINT << "Attempted to doubly register HTTP/3 stream ID " + << id() << " context ID " << context_id.value(); + return; + } + capsule_parser_->set_datagram_context_id_present(true); + } else { + // Registration without a context ID. + if (!datagram_context_visitors_.empty()) { + QUIC_BUG(h3 datagram context ID mix2) + << ENDPOINT + << "Attempted to mix registrations with and without context IDs " + "for stream ID " + << id(); + return; + } + if (datagram_no_context_visitor_ != nullptr) { + QUIC_BUG(h3 datagram double no context registration) + << ENDPOINT << "Attempted to doubly register HTTP/3 stream ID " + << id() << " with no context ID"; + return; + } + datagram_no_context_visitor_ = visitor; + capsule_parser_->set_datagram_context_id_present(false); + } + if (spdy_session_->http_datagram_support() == HttpDatagramSupport::kDraft04) { + const bool is_client = session()->perspective() == Perspective::IS_CLIENT; + if (context_id.has_value()) { + const bool is_client_context = context_id.value() % 2 == 0; + if (is_client == is_client_context) { + QuicConnection::ScopedPacketFlusher flusher( + spdy_session_->connection()); + WriteGreaseCapsule(); + if (context_id.value() != 0) { + WriteCapsule(Capsule::RegisterDatagramContext( + context_id.value(), format_type, format_additional_data)); + } else { + // draft-ietf-masque-h3-datagram-05 encodes context ID 0 using + // REGISTER_DATAGRAM_NO_CONTEXT. + WriteCapsule(Capsule::RegisterDatagramNoContext( + format_type, format_additional_data)); + } + WriteGreaseCapsule(); + } + } else if (is_client) { + QuicConnection::ScopedPacketFlusher flusher(spdy_session_->connection()); + WriteGreaseCapsule(); + WriteCapsule(Capsule::RegisterDatagramNoContext(format_type, + format_additional_data)); + WriteGreaseCapsule(); + } + } +} + +void QuicSpdyStream::UnregisterHttp3DatagramContextId( + absl::optional context_id) { + if (datagram_registration_visitor_ == nullptr) { + QUIC_BUG(context unregistration without registration visitor) + << ENDPOINT << "Cannot unregister context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " without registration visitor for stream ID " << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Unregistering datagram context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) + : "none") + << " with stream ID " << id(); + if (context_id.has_value()) { + size_t num_erased = datagram_context_visitors_.erase(context_id.value()); + QUIC_BUG_IF(h3 datagram unregister unknown context, num_erased != 1) + << "Attempted to unregister unknown HTTP/3 context ID " + << context_id.value() << " on stream ID " << id(); + } else { + // Unregistration without a context ID. + QUIC_BUG_IF(h3 datagram unknown context unregistration, + datagram_no_context_visitor_ == nullptr) + << "Attempted to unregister unknown no context on HTTP/3 stream ID " + << id(); + datagram_no_context_visitor_ = nullptr; + } + if (spdy_session_->http_datagram_support() == HttpDatagramSupport::kDraft04 && + context_id.has_value()) { + WriteCapsule(Capsule::CloseDatagramContext(context_id.value())); + } +} + +void QuicSpdyStream::MoveHttp3DatagramContextIdRegistration( + absl::optional context_id, + Http3DatagramVisitor* visitor) { + if (datagram_registration_visitor_ == nullptr) { + QUIC_BUG(context move without registration visitor) + << ENDPOINT << "Cannot move context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " without registration visitor for stream ID " << id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT << "Moving datagram context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) + : "none") + << " with stream ID " << id(); + if (context_id.has_value()) { + QUIC_BUG_IF(h3 datagram move unknown context, + !datagram_context_visitors_.contains(context_id.value())) + << ENDPOINT << "Attempted to move unknown context ID " + << context_id.value() << " on stream ID " << id(); + datagram_context_visitors_[context_id.value()] = visitor; + return; + } + // Move without a context ID. + QUIC_BUG_IF(h3 datagram unknown context move, + datagram_no_context_visitor_ == nullptr) + << "Attempted to move unknown no context on HTTP/3 stream ID " << id(); + datagram_no_context_visitor_ = visitor; +} + +void QuicSpdyStream::SetMaxDatagramTimeInQueue( + QuicTime::Delta max_time_in_queue) { + spdy_session_->SetMaxDatagramTimeInQueueForStreamId(id(), max_time_in_queue); +} + +QuicDatagramContextId QuicSpdyStream::GetNextDatagramContextId() { + QuicDatagramContextId result = datagram_next_available_context_id_; + datagram_next_available_context_id_ += kDatagramContextIdIncrement; + return result; +} + +void QuicSpdyStream::OnDatagramReceived(QuicDataReader* reader) { + if (!headers_decompressed_) { + QUIC_DLOG(INFO) << "Dropping datagram received before headers on stream ID " + << id(); + return; + } + absl::optional context_id; + if (use_datagram_contexts_) { + QuicDatagramContextId parsed_context_id; + if (!reader->ReadVarInt62(&parsed_context_id)) { + QUIC_DLOG(ERROR) << "Failed to parse context ID in received HTTP/3 " + "datagram on stream ID " + << id(); + return; + } + context_id = parsed_context_id; + } + absl::string_view payload = reader->ReadRemainingPayload(); + HandleReceivedDatagram(context_id, payload); +} + +QuicByteCount QuicSpdyStream::GetMaxDatagramSize( + absl::optional context_id) const { + QuicByteCount prefix_size = 0; + switch (spdy_session_->http_datagram_support()) { + case HttpDatagramSupport::kDraft00: + if (!datagram_flow_id_.has_value()) { + QUIC_BUG(GetMaxDatagramSize with no flow ID) + << "GetMaxDatagramSize() called when no flow ID available"; + break; + } + prefix_size = QuicDataWriter::GetVarInt62Len(*datagram_flow_id_); + break; + case HttpDatagramSupport::kDraft04: + prefix_size = + QuicDataWriter::GetVarInt62Len(id() / kHttpDatagramStreamIdDivisor); + break; + case HttpDatagramSupport::kNone: + case HttpDatagramSupport::kDraft00And04: + QUIC_BUG(GetMaxDatagramSize called with no datagram support) + << "GetMaxDatagramSize() called when no HTTP/3 datagram support has " + "been negotiated. Support value: " + << spdy_session_->http_datagram_support(); + break; + } + // If the logic above fails, use the largest possible value as the safe one. + if (prefix_size == 0) { + prefix_size = 8; + } + + if (context_id.has_value()) { + QUIC_BUG_IF( + context_id with draft00 in GetMaxDatagramSize, + spdy_session_->http_datagram_support() == HttpDatagramSupport::kDraft00) + << "GetMaxDatagramSize() called with a context ID specified, but " + "draft00 does not support contexts."; + prefix_size += QuicDataWriter::GetVarInt62Len(*context_id); + } + + QuicByteCount max_datagram_size = + session()->GetGuaranteedLargestMessagePayload(); + if (max_datagram_size < prefix_size) { + QUIC_BUG(max_datagram_size smaller than prefix_size) + << "GetGuaranteedLargestMessagePayload() returned a datagram size that " + "is not sufficient to fit stream and/or context ID into it."; + return 0; + } + return max_datagram_size - prefix_size; +} + +void QuicSpdyStream::RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id) { + datagram_flow_id_ = flow_id; + spdy_session_->RegisterHttp3DatagramFlowId(datagram_flow_id_.value(), id()); +} + +void QuicSpdyStream::HandleBodyAvailable() { + if (!capsule_parser_) { + OnBodyAvailable(); + return; + } + while (body_manager_.HasBytesToRead()) { + iovec iov; + int num_iov = GetReadableRegions(&iov, /*iov_len=*/1); + if (num_iov == 0) { + break; + } + if (!capsule_parser_->IngestCapsuleFragment(absl::string_view( + reinterpret_cast(iov.iov_base), iov.iov_len))) { + break; + } + MarkConsumed(iov.iov_len); + } + // If we received a FIN, make sure that there isn't a partial capsule buffered + // in the capsule parser. + if (sequencer()->IsClosed()) { + capsule_parser_->ErrorIfThereIsRemainingBufferedData(); + if (web_transport_ != nullptr) { + web_transport_->OnConnectStreamFinReceived(); + } + OnFinRead(); + } +} + +namespace { +// Return true if |c| is not allowed in an HTTP/3 wire-encoded header and +// pseudo-header names according to +// https://datatracker.ietf.org/doc/html/draft-ietf-quic-http#section-4.1.1 and +// https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-semantics-19#section-5.6.2 +constexpr bool isInvalidHeaderNameCharacter(unsigned char c) { + if (c == '!' || c == '|' || c == '~' || c == '*' || c == '+' || c == '-' || + c == '.' || + // #, $, %, &, ' + (c >= '#' && c <= '\'') || + // [0-9], : + (c >= '0' && c <= ':') || + // ^, _, `, [a-z] + (c >= '^' && c <= 'z')) { + return false; + } + return true; +} +} // namespace + +bool QuicSpdyStream::AreHeadersValid(const QuicHeaderList& header_list) const { + QUICHE_DCHECK(GetQuicReloadableFlag(quic_verify_request_headers_2)); + for (const std::pair& pair : header_list) { + const std::string& name = pair.first; + if (std::any_of(name.begin(), name.end(), isInvalidHeaderNameCharacter)) { + QUIC_DLOG(ERROR) << "Invalid request header " << name; + return false; + } + if (http2::GetInvalidHttp2HeaderSet().contains(name)) { + QUIC_DLOG(ERROR) << name << " header is not allowed"; + return false; + } + } + return true; +} + +void QuicSpdyStream::OnInvalidHeaders() { Reset(QUIC_BAD_APPLICATION_PAYLOAD); } + #undef ENDPOINT // undef for jumbo builds } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_stream.h b/gquiche/quic/core/http/quic_spdy_stream.h index 14f92d43..28322f41 100644 --- a/gquiche/quic/core/http/quic_spdy_stream.h +++ b/gquiche/quic/core/http/quic_spdy_stream.h @@ -16,18 +16,22 @@ #include #include +#include "absl/base/attributes.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "gquiche/quic/core/http/capsule.h" #include "gquiche/quic/core/http/http_decoder.h" #include "gquiche/quic/core/http/http_encoder.h" #include "gquiche/quic/core/http/quic_header_list.h" #include "gquiche/quic/core/http/quic_spdy_stream_body_manager.h" +#include "gquiche/quic/core/http/web_transport_stream_adapter.h" #include "gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_stream_sequencer.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/web_transport_interface.h" -#include "gquiche/quic/core/web_transport_stream_adapter.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_socket_address.h" @@ -47,6 +51,7 @@ class WebTransportHttp3; // A QUIC stream that can send and receive HTTP2 (SPDY) headers. class QUIC_EXPORT_PRIVATE QuicSpdyStream : public QuicStream, + public CapsuleParser::Visitor, public QpackDecodedHeadersAccumulator::Visitor { public: // Visitor receives callbacks from the stream. @@ -67,12 +72,9 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream virtual ~Visitor() {} }; - QuicSpdyStream(QuicStreamId id, - QuicSpdySession* spdy_session, - StreamType type); - QuicSpdyStream(PendingStream* pending, - QuicSpdySession* spdy_session, + QuicSpdyStream(QuicStreamId id, QuicSpdySession* spdy_session, StreamType type); + QuicSpdyStream(PendingStream* pending, QuicSpdySession* spdy_session); QuicSpdyStream(const QuicSpdyStream&) = delete; QuicSpdyStream& operator=(const QuicSpdyStream&) = delete; ~QuicSpdyStream() override; @@ -91,14 +93,12 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // Called by the session when decompressed headers have been completely // delivered to this stream. If |fin| is true, then this stream // should be closed; no more data will be sent by the peer. - virtual void OnStreamHeaderList(bool fin, - size_t frame_len, + virtual void OnStreamHeaderList(bool fin, size_t frame_len, const QuicHeaderList& header_list); // Called by the session when decompressed push promise headers have // been completely delivered to this stream. - virtual void OnPromiseHeaderList(QuicStreamId promised_id, - size_t frame_len, + virtual void OnPromiseHeaderList(QuicStreamId promised_id, size_t frame_len, const QuicHeaderList& header_list); // Called by the session when a PRIORITY frame has been been received for this @@ -108,8 +108,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // Override the base class to not discard response when receiving // QUIC_STREAM_NO_ERROR. void OnStreamReset(const QuicRstStreamFrame& frame) override; - - void Reset(QuicRstStreamErrorCode error) override; + void ResetWithError(QuicResetStreamError error) override; + bool OnStopSending(QuicResetStreamError error) override; // Called by the sequencer when new data is available. Decodes the data and // calls OnBodyAvailable() to pass to the upper layer. @@ -123,8 +123,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // number of bytes sent, including data sent on the encoder stream when using // QPACK. virtual size_t WriteHeaders( - spdy::SpdyHeaderBlock header_block, - bool fin, + spdy::SpdyHeaderBlock header_block, bool fin, QuicReferenceCountedPointer ack_listener); // Sends |data| to the peer, or buffers if it can't be sent immediately. @@ -139,10 +138,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream QuicReferenceCountedPointer ack_listener); // Override to report newly acked bytes via ack_listener_. - bool OnStreamFrameAcked(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin_acked, - QuicTime::Delta ack_delay_time, + bool OnStreamFrameAcked(QuicStreamOffset offset, QuicByteCount data_length, + bool fin_acked, QuicTime::Delta ack_delay_time, QuicTime receive_timestamp, QuicByteCount* newly_acked_length) override; @@ -157,7 +154,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // Does the same thing as WriteOrBufferBody except this method takes // memslicespan as the data input. Right now it only calls WriteMemSlices. - QuicConsumedData WriteBodySlices(QuicMemSliceSpan slices, bool fin); + QuicConsumedData WriteBodySlices(absl::Span slices, bool fin); // Marks the trailers as consumed. This applies to the case where this object // receives headers and trailers as QuicHeaderLists via calls to @@ -215,7 +212,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // QpackDecodedHeadersAccumulator::Visitor implementation. void OnHeadersDecoded(QuicHeaderList headers, bool header_list_size_limit_exceeded) override; - void OnHeaderDecodingError(absl::string_view error_message) override; + void OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) override; QuicSpdySession* spdy_session() const { return spdy_session_; } @@ -244,20 +242,123 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // error, and returns false. bool AssertNotWebTransportDataStream(absl::string_view operation); + // Indicates whether a call to WriteBodySlices will be successful and not + // rejected due to buffer being full. |write_size| must be non-zero. + bool CanWriteNewBodyData(QuicByteCount write_size) const; + + // From CapsuleParser::Visitor. + bool OnCapsule(const Capsule& capsule) override; + void OnCapsuleParseFailure(const std::string& error_message) override; + + // Sends an HTTP/3 datagram. The stream and context IDs are not part of + // |payload|. + MessageStatus SendHttp3Datagram( + absl::optional context_id, + absl::string_view payload); + + class QUIC_EXPORT_PRIVATE Http3DatagramVisitor { + public: + virtual ~Http3DatagramVisitor() {} + + // Called when an HTTP/3 datagram is received. |payload| does not contain + // the stream or context IDs. Note that this contains the stream ID even if + // flow IDs from draft-ietf-masque-h3-datagram-00 are in use. + virtual void OnHttp3Datagram( + QuicStreamId stream_id, + absl::optional context_id, + absl::string_view payload) = 0; + }; + + class QUIC_EXPORT_PRIVATE Http3DatagramRegistrationVisitor { + public: + virtual ~Http3DatagramRegistrationVisitor() {} + + // Called when a REGISTER_DATAGRAM_CONTEXT or REGISTER_DATAGRAM_NO_CONTEXT + // capsule is received. Note that this contains the stream ID even if flow + // IDs from draft-ietf-masque-h3-datagram-00 are in use. + virtual void OnContextReceived( + QuicStreamId stream_id, + absl::optional context_id, + DatagramFormatType format_type, + absl::string_view format_additional_data) = 0; + + // Called when a CLOSE_DATAGRAM_CONTEXT capsule is received. Note that this + // contains the stream ID even if flow IDs from + // draft-ietf-masque-h3-datagram-00 are in use. + virtual void OnContextClosed( + QuicStreamId stream_id, + absl::optional context_id, + ContextCloseCode close_code, absl::string_view close_details) = 0; + }; + + // Registers |visitor| to receive HTTP/3 datagram context registrations. This + // must not be called without first calling + // UnregisterHttp3DatagramRegistrationVisitor. |visitor| must be valid until a + // corresponding call to UnregisterHttp3DatagramRegistrationVisitor. + void RegisterHttp3DatagramRegistrationVisitor( + Http3DatagramRegistrationVisitor* visitor, + bool use_datagram_contexts = false); + + // Unregisters for HTTP/3 datagram context registrations. Must not be called + // unless previously registered. + void UnregisterHttp3DatagramRegistrationVisitor(); + + // Moves an HTTP/3 datagram registration to a different visitor. Mainly meant + // to be used by the visitors' move operators. + void MoveHttp3DatagramRegistration(Http3DatagramRegistrationVisitor* visitor); + + // Registers |visitor| to receive HTTP/3 datagrams for optional context ID + // |context_id|. This must not be called on a previously registered context ID + // without first calling UnregisterHttp3DatagramContextId. |visitor| must be + // valid until a corresponding call to UnregisterHttp3DatagramContextId. If + // this method is called multiple times, the context ID MUST either be always + // present, or always absent. + void RegisterHttp3DatagramContextId( + absl::optional context_id, + DatagramFormatType format_type, absl::string_view format_additional_data, + Http3DatagramVisitor* visitor); + + // Unregisters an HTTP/3 datagram context ID. Must be called on a previously + // registered context. + void UnregisterHttp3DatagramContextId( + absl::optional context_id); + + // Moves an HTTP/3 datagram context ID to a different visitor. Mainly meant + // to be used by the visitors' move operators. + void MoveHttp3DatagramContextIdRegistration( + absl::optional context_id, + Http3DatagramVisitor* visitor); + + // Sets max datagram time in queue. + void SetMaxDatagramTimeInQueue(QuicTime::Delta max_time_in_queue); + + // Generates a new HTTP/3 datagram context ID for this stream. A datagram + // registration visitor must be currently registered on this stream. + QuicDatagramContextId GetNextDatagramContextId(); + + void OnDatagramReceived(QuicDataReader* reader); + + void RegisterHttp3DatagramFlowId(QuicDatagramStreamId flow_id); + + QuicByteCount GetMaxDatagramSize( + absl::optional context_id) const; + + // Writes |capsule| onto the DATA stream. + void WriteCapsule(const Capsule& capsule, bool fin = false); + + void WriteGreaseCapsule(); + protected: // Called when the received headers are too large. By default this will // reset the stream. virtual void OnHeadersTooLarge(); - virtual void OnInitialHeadersComplete(bool fin, - size_t frame_len, + virtual void OnInitialHeadersComplete(bool fin, size_t frame_len, const QuicHeaderList& header_list); - virtual void OnTrailingHeadersComplete(bool fin, - size_t frame_len, + virtual void OnTrailingHeadersComplete(bool fin, size_t frame_len, const QuicHeaderList& header_list); virtual size_t WriteHeadersImpl( - spdy::SpdyHeaderBlock header_block, - bool fin, + spdy::SpdyHeaderBlock header_block, bool fin, QuicReferenceCountedPointer ack_listener); Visitor* visitor() { return visitor_; } @@ -269,6 +370,13 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream ack_listener_ = std::move(ack_listener); } + void OnWriteSideInDataRecvdState() override; + + virtual bool AreHeadersValid(const QuicHeaderList& header_list) const; + + // Reset stream upon invalid request headers. + virtual void OnInvalidHeaders(); + private: friend class test::QuicSpdyStreamPeer; friend class test::QuicStreamPeer; @@ -292,16 +400,9 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream QuicByteCount payload_length); bool OnHeadersFramePayload(absl::string_view payload); bool OnHeadersFrameEnd(); - bool OnPushPromiseFrameStart(QuicByteCount header_length); - bool OnPushPromiseFramePushId(PushId push_id, - QuicByteCount push_id_length, - QuicByteCount header_block_length); - bool OnPushPromiseFramePayload(absl::string_view payload); - bool OnPushPromiseFrameEnd(); void OnWebTransportStreamFrameType(QuicByteCount header_length, WebTransportSessionId session_id); - bool OnUnknownFrameStart(uint64_t frame_type, - QuicByteCount header_length, + bool OnUnknownFrameStart(uint64_t frame_type, QuicByteCount header_length, QuicByteCount payload_length); bool OnUnknownFramePayload(absl::string_view payload); bool OnUnknownFrameEnd(); @@ -314,6 +415,22 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream void MaybeProcessSentWebTransportHeaders(spdy::SpdyHeaderBlock& headers); void MaybeProcessReceivedWebTransportHeaders(); + // Writes HTTP/3 DATA frame header. If |force_write| is true, use + // WriteOrBufferData if send buffer cannot accomodate the header + data. + ABSL_MUST_USE_RESULT bool WriteDataFrameHeader(QuicByteCount data_length, + bool force_write); + + // Simply calls OnBodyAvailable() unless capsules are in use, in which case + // pass the capsule fragments to the capsule manager. + void HandleBodyAvailable(); + + // Called when a datagram frame or capsule is received. + void HandleReceivedDatagram(absl::optional context_id, + absl::string_view payload); + + // Whether datagram contexts should be used on this stream. + bool ShouldUseDatagramContexts() const; + QuicSpdySession* spdy_session_; bool on_body_available_called_because_sequencer_is_closed_; @@ -354,6 +471,8 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // the sequencer each time new stream data is processed. QuicSpdyStreamBodyManager body_manager_; + std::unique_ptr capsule_parser_; + // Sequencer offset keeping track of how much data HttpDecoder has processed. // Initial value is zero for fresh streams, or sequencer()->NumBytesConsumed() // at time of construction if a PendingStream is converted to account for the @@ -382,6 +501,15 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStream // If this stream is a WebTransport data stream, |web_transport_data_| // contains all of the associated metadata. std::unique_ptr web_transport_data_; + + // HTTP/3 Datagram support. + Http3DatagramRegistrationVisitor* datagram_registration_visitor_ = nullptr; + Http3DatagramVisitor* datagram_no_context_visitor_ = nullptr; + absl::optional datagram_flow_id_; + QuicDatagramContextId datagram_next_available_context_id_; + absl::flat_hash_map + datagram_context_visitors_; + bool use_datagram_contexts_ = false; }; } // namespace quic diff --git a/gquiche/quic/core/http/quic_spdy_stream_body_manager.h b/gquiche/quic/core/http/quic_spdy_stream_body_manager.h index 765c3a1b..9292e143 100644 --- a/gquiche/quic/core/http/quic_spdy_stream_body_manager.h +++ b/gquiche/quic/core/http/quic_spdy_stream_body_manager.h @@ -7,11 +7,11 @@ #include "absl/base/attributes.h" #include "absl/strings/string_view.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_iovec.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -84,7 +84,7 @@ class QUIC_EXPORT_PRIVATE QuicSpdyStreamBodyManager { QuicByteCount trailing_non_body_byte_count; }; // Queue of body fragments and trailing non-body byte counts. - QuicCircularDeque fragments_; + quiche::QuicheCircularDeque fragments_; // Total body bytes received. QuicByteCount total_body_bytes_received_; }; diff --git a/gquiche/quic/core/http/quic_spdy_stream_test.cc b/gquiche/quic/core/http/quic_spdy_stream_test.cc index 96f6452d..61e52030 100644 --- a/gquiche/quic/core/http/quic_spdy_stream_test.cc +++ b/gquiche/quic/core/http/quic_spdy_stream_test.cc @@ -16,15 +16,17 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/crypto/null_encrypter.h" #include "gquiche/quic/core/http/http_encoder.h" +#include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/core/http/web_transport_http3.h" #include "gquiche/quic/core/quic_connection.h" +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/core/quic_stream_sequencer_buffer.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/core/quic_write_blocked_list.h" #include "gquiche/quic/platform/api/quic_expect_bug.h" -#include "gquiche/quic/platform/api/quic_map_util.h" +#include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" #include "gquiche/quic/test_tools/quic_config_peer.h" @@ -156,10 +158,19 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { ConnectionCloseSource /*source*/) override {} void OnHandshakeDoneReceived() override {} void OnNewTokenReceived(absl::string_view /*token*/) override {} - std::string GetAddressToken() const override { return ""; } + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } bool ValidateAddressToken(absl::string_view /*token*/) const override { return true; } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} MOCK_METHOD(void, OnCanWrite, (), (override)); @@ -167,6 +178,15 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { MOCK_METHOD(bool, HasPendingRetransmission, (), (const, override)); + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + + SSL* GetSsl() const override { return nullptr; } + private: using QuicCryptoStream::session; @@ -177,8 +197,7 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { class TestStream : public QuicSpdyStream { public: - TestStream(QuicStreamId id, - QuicSpdySession* session, + TestStream(QuicStreamId id, QuicSpdySession* session, bool should_process_data) : QuicSpdyStream(id, session, BIDIRECTIONAL), should_process_data_(should_process_data), @@ -203,8 +222,7 @@ class TestStream : public QuicSpdyStream { MOCK_METHOD(void, WriteHeadersMock, (bool fin), ()); - size_t WriteHeadersImpl(spdy::SpdyHeaderBlock header_block, - bool fin, + size_t WriteHeadersImpl(spdy::SpdyHeaderBlock header_block, bool fin, QuicReferenceCountedPointer /*ack_listener*/) override { saved_headers_ = std::move(header_block); @@ -226,8 +244,7 @@ class TestStream : public QuicSpdyStream { return QuicStream::sequencer(); } - void OnStreamHeaderList(bool fin, - size_t frame_len, + void OnStreamHeaderList(bool fin, size_t frame_len, const QuicHeaderList& header_list) override { headers_payload_length_ = frame_len; QuicSpdyStream::OnStreamHeaderList(fin, frame_len, header_list); @@ -235,6 +252,11 @@ class TestStream : public QuicSpdyStream { size_t headers_payload_length() const { return headers_payload_length_; } + bool AreHeadersValid(const QuicHeaderList& header_list) const override { + return !GetQuicReloadableFlag(quic_verify_request_headers_2) || + QuicSpdyStream::AreHeadersValid(header_list); + } + private: bool should_process_data_; spdy::SpdyHeaderBlock saved_headers_; @@ -259,8 +281,16 @@ class TestSession : public MockQuicSpdySession { bool ShouldNegotiateWebTransport() override { return enable_webtransport_; } void EnableWebTransport() { enable_webtransport_ = true; } + HttpDatagramSupport LocalHttpDatagramSupport() override { + return local_http_datagram_support_; + } + void set_local_http_datagram_support(HttpDatagramSupport value) { + local_http_datagram_support_ = value; + } + private: bool enable_webtransport_ = false; + HttpDatagramSupport local_http_datagram_support_ = HttpDatagramSupport::kNone; StrictMock crypto_stream_; }; @@ -272,8 +302,7 @@ class TestMockUpdateStreamSession : public MockQuicSpdySession { spdy::SpdyStreamPrecedence(QuicStream::kDefaultPriority)) {} void UpdateStreamPriority( - QuicStreamId id, - const spdy::SpdyStreamPrecedence& precedence) override { + QuicStreamId id, const spdy::SpdyStreamPrecedence& precedence) override { EXPECT_EQ(id, expected_stream_->id()); EXPECT_EQ(expected_precedence_, precedence); EXPECT_EQ(expected_precedence_, expected_stream_->precedence()); @@ -447,40 +476,10 @@ class QuicSpdyStreamTest : public QuicTestWithParam { return absl::StrCat(headers_frame_header, payload); } - // Construct PUSH_PROMISE frame with given payload. - // TODO(b/171463363): Remove. - std::string SerializePushPromiseFrame(PushId push_id, - absl::string_view headers) { - const QuicByteCount payload_length = - QuicDataWriter::GetVarInt62Len(push_id) + headers.length(); - - const QuicByteCount length_without_headers = - QuicDataWriter::GetVarInt62Len( - static_cast(HttpFrameType::PUSH_PROMISE)) + - QuicDataWriter::GetVarInt62Len(payload_length) + - QuicDataWriter::GetVarInt62Len(push_id); - - std::string push_promise_frame(length_without_headers, '\0'); - QuicDataWriter writer(length_without_headers, &*push_promise_frame.begin()); - - QUICHE_CHECK(writer.WriteVarInt62( - static_cast(HttpFrameType::PUSH_PROMISE))); - QUICHE_CHECK(writer.WriteVarInt62(payload_length)); - QUICHE_CHECK(writer.WriteVarInt62(push_id)); - QUICHE_CHECK_EQ(0u, writer.remaining()); - - absl::StrAppend(&push_promise_frame, headers); - - return push_promise_frame; - } - std::string DataFrame(absl::string_view payload) { - std::unique_ptr data_buffer; - QuicByteCount data_frame_header_length = - HttpEncoder::SerializeDataFrameHeader(payload.length(), &data_buffer); - absl::string_view data_frame_header(data_buffer.get(), - data_frame_header_length); - return absl::StrCat(data_frame_header, payload); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + payload.length(), SimpleBufferAllocator::Get()); + return absl::StrCat(header.AsStringView(), payload); } std::string UnknownFrame(uint64_t frame_type, absl::string_view payload) { @@ -513,8 +512,7 @@ class QuicSpdyStreamTest : public QuicTestWithParam { SpdyHeaderBlock headers_; }; -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSpdyStreamTest, +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdyStreamTest, ::testing::ValuesIn(AllSupportedVersions()), ::testing::PrintToStringParamName()); @@ -537,8 +535,11 @@ TEST_P(QuicSpdyStreamTest, ProcessTooLargeHeaderList) { stream_->OnStreamHeadersPriority( spdy::SpdyStreamPrecedence(kV3HighestPriority)); - EXPECT_CALL(*session_, MaybeSendRstStreamFrame(stream_->id(), - QUIC_HEADERS_TOO_LARGE, 0)); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_HEADERS_TOO_LARGE), 0)); stream_->OnStreamHeaderList(false, 1 << 20, headers); EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_HEADERS_TOO_LARGE)); @@ -553,10 +554,14 @@ TEST_P(QuicSpdyStreamTest, ProcessTooLargeHeaderList) { QuicStreamFrame frame(stream_->id(), false, 0, headers); - EXPECT_CALL(*session_, - MaybeSendStopSendingFrame(stream_->id(), QUIC_HEADERS_TOO_LARGE)); - EXPECT_CALL(*session_, MaybeSendRstStreamFrame(stream_->id(), - QUIC_HEADERS_TOO_LARGE, 0)); + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + stream_->id(), QuicResetStreamError::FromInternal( + QUIC_HEADERS_TOO_LARGE))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_HEADERS_TOO_LARGE), 0)); auto qpack_decoder_stream = QuicSpdySessionPeer::GetQpackDecoderSendStream(session_.get()); @@ -1005,11 +1010,10 @@ TEST_P(QuicSpdyStreamTest, StreamFlowControlNoWindowUpdateIfNotConsumed) { std::string data; if (UsesHttp3()) { - std::unique_ptr buffer; - header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - data = header + body; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); + data = absl::StrCat(header.AsStringView(), body); + header_length = header.size(); } else { data = body; } @@ -1049,11 +1053,10 @@ TEST_P(QuicSpdyStreamTest, StreamFlowControlWindowUpdate) { std::string data; if (UsesHttp3()) { - std::unique_ptr buffer; - header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - data = header + body; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); + data = absl::StrCat(header.AsStringView(), body); + header_length = header.size(); } else { data = body; } @@ -1115,16 +1118,13 @@ TEST_P(QuicSpdyStreamTest, ConnectionFlowControlWindowUpdate) { if (UsesHttp3()) { body = std::string(kWindow / 4 - 2, 'a'); - std::unique_ptr buffer; - header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - data = header + body; - std::unique_ptr buffer2; - QuicByteCount header_length2 = - HttpEncoder::SerializeDataFrameHeader(body2.length(), &buffer2); - std::string header2 = std::string(buffer2.get(), header_length2); - data2 = header2 + body2; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); + data = absl::StrCat(header.AsStringView(), body); + header_length = header.size(); + QuicBuffer header2 = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); + data2 = absl::StrCat(header2.AsStringView(), body2); } else { body = std::string(kWindow / 4, 'a'); data = body; @@ -1595,8 +1595,9 @@ TEST_P(QuicSpdyStreamTest, WritingTrailersFinalOffset) { std::string body(1024, 'x'); // 1 kB QuicByteCount header_length = 0; if (UsesHttp3()) { - std::unique_ptr buf; - header_length = HttpEncoder::SerializeDataFrameHeader(body.length(), &buf); + header_length = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()) + .size(); } stream_->WriteOrBufferBody(body, false); @@ -1939,30 +1940,26 @@ TEST_P(QuicSpdyStreamTest, HeadersAckNotReportedWriteOrBufferBody) { stream_->WriteOrBufferBody(body, false); stream_->WriteOrBufferBody(body2, true); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - - header_length = - HttpEncoder::SerializeDataFrameHeader(body2.length(), &buffer); - std::string header2 = std::string(buffer.get(), header_length); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); + QuicBuffer header2 = HttpEncoder::SerializeDataFrameHeader( + body2.length(), SimpleBufferAllocator::Get()); EXPECT_CALL(*mock_ack_listener, OnPacketAcked(body.length(), _)); - QuicStreamFrame frame(stream_->id(), false, 0, header + body); + QuicStreamFrame frame(stream_->id(), false, 0, + absl::StrCat(header.AsStringView(), body)); EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame), QuicTime::Delta::Zero(), QuicTime::Zero())); EXPECT_CALL(*mock_ack_listener, OnPacketAcked(0, _)); - QuicStreamFrame frame2(stream_->id(), false, header.length() + body.length(), - header2); + QuicStreamFrame frame2(stream_->id(), false, header.size() + body.length(), + header2.AsStringView()); EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame2), QuicTime::Delta::Zero(), QuicTime::Zero())); EXPECT_CALL(*mock_ack_listener, OnPacketAcked(body2.length(), _)); QuicStreamFrame frame3(stream_->id(), true, - header.length() + body.length() + header2.length(), - body2); + header.size() + body.length() + header2.size(), body2); EXPECT_TRUE(session_->OnFrameAcked(QuicFrame(frame3), QuicTime::Delta::Zero(), QuicTime::Zero())); @@ -2497,10 +2494,14 @@ TEST_P(QuicSpdyStreamTest, HeaderDecodingUnblockedAfterStreamClosed) { /* offset = */ 1, _, _, _)); // Reset stream by this endpoint, for example, due to stream cancellation. - EXPECT_CALL(*session_, - MaybeSendStopSendingFrame(stream_->id(), QUIC_STREAM_CANCELLED)); - EXPECT_CALL(*session_, - MaybeSendRstStreamFrame(stream_->id(), QUIC_STREAM_CANCELLED, _)); + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + stream_->id(), QuicResetStreamError::FromInternal( + QUIC_STREAM_CANCELLED))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), _)); stream_->Reset(QUIC_STREAM_CANCELLED); // Deliver dynamic table entry to decoder. @@ -2605,8 +2606,7 @@ class QuicSpdyStreamIncrementalConsumptionTest : public QuicSpdyStreamTest { QuicStreamOffset consumed_bytes_; }; -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSpdyStreamIncrementalConsumptionTest, +INSTANTIATE_TEST_SUITE_P(Tests, QuicSpdyStreamIncrementalConsumptionTest, ::testing::ValuesIn(AllSupportedVersions()), ::testing::PrintToStringParamName()); @@ -2769,65 +2769,6 @@ TEST_P(QuicSpdyStreamIncrementalConsumptionTest, UnknownFramesInterleaved) { EXPECT_EQ(unknown_frame4.size(), NewlyConsumedBytes()); } -// TODO(b/171463363): Remove. -TEST_P(QuicSpdyStreamTest, PushPromiseOnDataStream) { - Initialize(kShouldProcessData); - if (!UsesHttp3()) { - return; - } - - StrictMock debug_visitor; - session_->set_debug_visitor(&debug_visitor); - - SpdyHeaderBlock pushed_headers; - pushed_headers["foo"] = "bar"; - std::string headers = EncodeQpackHeaders(pushed_headers); - - const QuicStreamId push_id = 1; - std::string data = SerializePushPromiseFrame(push_id, headers); - QuicStreamFrame frame(stream_->id(), false, 0, data); - - EXPECT_CALL(debug_visitor, OnPushPromiseFrameReceived(stream_->id(), push_id, - headers.length())); - EXPECT_CALL(debug_visitor, - OnPushPromiseDecoded(stream_->id(), push_id, - AsHeaderList(pushed_headers))); - EXPECT_CALL(*session_, - OnPromiseHeaderList(stream_->id(), push_id, headers.length(), _)); - stream_->OnStreamFrame(frame); -} - -// Regression test for b/152518220. -// TODO(b/171463363): Remove. -TEST_P(QuicSpdyStreamTest, - OnStreamHeaderBlockArgumentDoesNotIncludePushedHeaderBlock) { - Initialize(kShouldProcessData); - if (!UsesHttp3()) { - return; - } - - std::string pushed_headers = EncodeQpackHeaders({{"foo", "bar"}}); - const QuicStreamId push_id = 1; - std::string push_promise_frame = - SerializePushPromiseFrame(push_id, pushed_headers); - QuicStreamOffset offset = 0; - QuicStreamFrame frame1(stream_->id(), /* fin = */ false, offset, - push_promise_frame); - offset += push_promise_frame.length(); - - EXPECT_CALL(*session_, OnPromiseHeaderList(stream_->id(), push_id, - pushed_headers.length(), _)); - stream_->OnStreamFrame(frame1); - - std::string headers = - EncodeQpackHeaders({{":method", "GET"}, {":path", "/"}}); - std::string headers_frame = HeadersFrame(headers); - QuicStreamFrame frame2(stream_->id(), /* fin = */ false, offset, - headers_frame); - stream_->OnStreamFrame(frame2); - EXPECT_EQ(headers.length(), stream_->headers_payload_length()); -} - // Close connection if a DATA frame is received before a HEADERS frame. TEST_P(QuicSpdyStreamTest, DataBeforeHeaders) { if (!UsesHttp3()) { @@ -3003,10 +2944,14 @@ TEST_P(QuicSpdyStreamTest, StreamCancellationWhenStreamReset) { EXPECT_CALL(*session_, WritevData(qpack_decoder_stream->id(), /* write_length = */ 1, /* offset = */ 1, _, _, _)); - EXPECT_CALL(*session_, - MaybeSendStopSendingFrame(stream_->id(), QUIC_STREAM_CANCELLED)); - EXPECT_CALL(*session_, - MaybeSendRstStreamFrame(stream_->id(), QUIC_STREAM_CANCELLED, _)); + EXPECT_CALL(*session_, MaybeSendStopSendingFrame( + stream_->id(), QuicResetStreamError::FromInternal( + QUIC_STREAM_CANCELLED))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), _)); stream_->Reset(QUIC_STREAM_CANCELLED); } @@ -3098,22 +3043,29 @@ TEST_P(QuicSpdyStreamTest, TwoResetStreamFrames) { EXPECT_TRUE(stream_->read_side_closed()); EXPECT_FALSE(stream_->write_side_closed()); } else { - EXPECT_CALL(*session_, MaybeSendRstStreamFrame( - stream_->id(), QUIC_RST_ACKNOWLEDGEMENT, _)); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), _)); EXPECT_QUIC_BUG( stream_->OnStreamReset(rst_frame2), "The stream should've already sent RST in response to STOP_SENDING"); } } -TEST_P(QuicSpdyStreamTest, ProcessOutgoingWebTransportHeaders) { +TEST_P(QuicSpdyStreamTest, ProcessOutgoingWebTransportHeadersDatagramDraft00) { if (!UsesHttp3()) { return; } InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_->EnableWebTransport(); - QuicSpdySessionPeer::EnableWebTransport(*session_); + session_->OnSetting(SETTINGS_ENABLE_CONNECT_PROTOCOL, 1); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft00); EXPECT_CALL(*stream_, WriteHeadersMock(false)); EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) @@ -3122,25 +3074,109 @@ TEST_P(QuicSpdyStreamTest, ProcessOutgoingWebTransportHeaders) { spdy::SpdyHeaderBlock headers; headers[":method"] = "CONNECT"; headers[":protocol"] = "webtransport"; - headers["datagram-flow-id"] = absl::StrCat(session_->GetNextDatagramFlowId()); + headers["datagram-flow-id"] = absl::StrCat(stream_->id()); stream_->WriteHeaders(std::move(headers), /*fin=*/false, nullptr); ASSERT_TRUE(stream_->web_transport() != nullptr); EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); } -TEST_P(QuicSpdyStreamTest, ProcessIncomingWebTransportHeaders) { +TEST_P(QuicSpdyStreamTest, ProcessOutgoingWebTransportHeadersDatagramDraft04) { + if (!UsesHttp3()) { + return; + } + + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + session_->EnableWebTransport(); + session_->OnSetting(SETTINGS_ENABLE_CONNECT_PROTOCOL, 1); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft04); + + EXPECT_CALL(*stream_, WriteHeadersMock(false)); + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(AnyNumber()); + + spdy::SpdyHeaderBlock headers; + headers[":method"] = "CONNECT"; + headers[":protocol"] = "webtransport"; + stream_->WriteHeaders(std::move(headers), /*fin=*/false, nullptr); + ASSERT_TRUE(stream_->web_transport() != nullptr); + EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); +} + +TEST_P(QuicSpdyStreamTest, ProcessIncomingWebTransportHeadersDatagramDraft04) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + session_->EnableWebTransport(); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft04); + + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + headers_["sec-webtransport-http3-draft02"] = "1"; + + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + ProcessHeaders(false, headers_); + stream_->OnCapsule( + Capsule::RegisterDatagramNoContext(DatagramFormatType::WEBTRANSPORT)); + EXPECT_EQ("", stream_->data()); + EXPECT_FALSE(stream_->header_list().empty()); + EXPECT_FALSE(stream_->IsDoneReading()); + ASSERT_TRUE(stream_->web_transport() != nullptr); + EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); +} + +TEST_P(QuicSpdyStreamTest, ProcessIncomingWebTransportHeadersDatagramDraft00) { + if (!UsesHttp3()) { + return; + } + + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + session_->EnableWebTransport(); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft00); + + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + headers_["datagram-flow-id"] = absl::StrCat(stream_->id()); + + stream_->OnStreamHeadersPriority( + spdy::SpdyStreamPrecedence(kV3HighestPriority)); + ProcessHeaders(false, headers_); + EXPECT_EQ("", stream_->data()); + EXPECT_FALSE(stream_->header_list().empty()); + EXPECT_FALSE(stream_->IsDoneReading()); + ASSERT_TRUE(stream_->web_transport() != nullptr); + EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); +} + +TEST_P(QuicSpdyStreamTest, + ProcessIncomingWebTransportHeadersWithMismatchedFlowId) { if (!UsesHttp3()) { return; } + // TODO(b/181256914) Remove this test when we deprecate + // draft-ietf-masque-h3-datagram-00 in favor of later drafts. Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); session_->EnableWebTransport(); - QuicSpdySessionPeer::EnableWebTransport(*session_); + QuicSpdySessionPeer::EnableWebTransport(session_.get()); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft00); headers_[":method"] = "CONNECT"; headers_[":protocol"] = "webtransport"; - headers_["datagram-flow-id"] = - absl::StrCat(session_->GetNextDatagramFlowId()); + headers_["datagram-flow-id"] = "2"; stream_->OnStreamHeadersPriority( spdy::SpdyStreamPrecedence(kV3HighestPriority)); @@ -3152,6 +3188,242 @@ TEST_P(QuicSpdyStreamTest, ProcessIncomingWebTransportHeaders) { EXPECT_EQ(stream_->id(), stream_->web_transport()->id()); } +TEST_P(QuicSpdyStreamTest, GetNextDatagramContextIdClient) { + if (!UsesHttp3()) { + return; + } + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + ::testing::NiceMock visitor; + stream_->RegisterHttp3DatagramRegistrationVisitor(&visitor); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 0u); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 2u); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 4u); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 6u); + stream_->UnregisterHttp3DatagramRegistrationVisitor(); +} + +TEST_P(QuicSpdyStreamTest, GetNextDatagramContextIdServer) { + if (!UsesHttp3()) { + return; + } + InitializeWithPerspective(kShouldProcessData, Perspective::IS_SERVER); + ::testing::NiceMock visitor; + stream_->RegisterHttp3DatagramRegistrationVisitor(&visitor); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 1u); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 3u); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 5u); + EXPECT_EQ(stream_->GetNextDatagramContextId(), 7u); + stream_->UnregisterHttp3DatagramRegistrationVisitor(); +} + +TEST_P(QuicSpdyStreamTest, HttpDatagramRegistrationWithoutContextDraft00) { + if (!UsesHttp3()) { + return; + } + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft00); + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + headers_["datagram-flow-id"] = absl::StrCat(stream_->id()); + ProcessHeaders(false, headers_); + session_->RegisterHttp3DatagramFlowId(stream_->id(), stream_->id()); + ::testing::NiceMock + h3_datagram_registration_visitor; + SavingHttp3DatagramVisitor h3_datagram_visitor; + absl::optional context_id; + absl::string_view format_additional_data; + ASSERT_EQ(QuicDataWriter::GetVarInt62Len(stream_->id()), 1); + std::array datagram; + datagram[0] = stream_->id(); + for (size_t i = 1; i < datagram.size(); i++) { + datagram[i] = i; + } + stream_->RegisterHttp3DatagramRegistrationVisitor( + &h3_datagram_registration_visitor); + stream_->RegisterHttp3DatagramContextId( + context_id, DatagramFormatType::UDP_PAYLOAD, format_additional_data, + &h3_datagram_visitor); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT(h3_datagram_visitor.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), context_id, + std::string(&datagram[1], datagram.size() - 1)})); + // Test move. + ::testing::NiceMock + h3_datagram_registration_visitor2; + stream_->MoveHttp3DatagramRegistration(&h3_datagram_registration_visitor2); + SavingHttp3DatagramVisitor h3_datagram_visitor2; + stream_->MoveHttp3DatagramContextIdRegistration(context_id, + &h3_datagram_visitor2); + EXPECT_TRUE(h3_datagram_visitor2.received_h3_datagrams().empty()); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT(h3_datagram_visitor2.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), context_id, + std::string(&datagram[1], datagram.size() - 1)})); + // Cleanup. + stream_->UnregisterHttp3DatagramContextId(context_id); + stream_->UnregisterHttp3DatagramRegistrationVisitor(); + session_->UnregisterHttp3DatagramFlowId(stream_->id()); +} + +TEST_P(QuicSpdyStreamTest, H3DatagramRegistrationWithoutContextDraft04) { + if (!UsesHttp3()) { + return; + } + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft04); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft04); + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + ProcessHeaders(false, headers_); + ::testing::NiceMock + h3_datagram_registration_visitor; + SavingHttp3DatagramVisitor h3_datagram_visitor; + absl::optional context_id; + absl::string_view format_additional_data; + ASSERT_EQ(QuicDataWriter::GetVarInt62Len(stream_->id()), 1); + std::array datagram; + datagram[0] = stream_->id(); + for (size_t i = 1; i < datagram.size(); i++) { + datagram[i] = i; + } + stream_->RegisterHttp3DatagramRegistrationVisitor( + &h3_datagram_registration_visitor); + + // Expect us to send a REGISTER_DATAGRAM_NO_CONTEXT capsule. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(AtLeast(1)); + + stream_->RegisterHttp3DatagramContextId( + context_id, DatagramFormatType::UDP_PAYLOAD, format_additional_data, + &h3_datagram_visitor); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT(h3_datagram_visitor.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), context_id, + std::string(&datagram[1], datagram.size() - 1)})); + // Test move. + ::testing::NiceMock + h3_datagram_registration_visitor2; + stream_->MoveHttp3DatagramRegistration(&h3_datagram_registration_visitor2); + SavingHttp3DatagramVisitor h3_datagram_visitor2; + stream_->MoveHttp3DatagramContextIdRegistration(context_id, + &h3_datagram_visitor2); + EXPECT_TRUE(h3_datagram_visitor2.received_h3_datagrams().empty()); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT(h3_datagram_visitor2.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), context_id, + std::string(&datagram[1], datagram.size() - 1)})); + // Cleanup. + stream_->UnregisterHttp3DatagramContextId(context_id); + stream_->UnregisterHttp3DatagramRegistrationVisitor(); +} + +TEST_P(QuicSpdyStreamTest, HttpDatagramRegistrationWithContext) { + if (!UsesHttp3()) { + return; + } + InitializeWithPerspective(kShouldProcessData, Perspective::IS_CLIENT); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft04); + ::testing::NiceMock + h3_datagram_registration_visitor; + SavingHttp3DatagramVisitor h3_datagram_visitor; + absl::optional context_id = 42; + absl::string_view format_additional_data; + ASSERT_EQ(QuicDataWriter::GetVarInt62Len(stream_->id()), 1); + std::array datagram; + datagram[0] = stream_->id(); + datagram[1] = context_id.value(); + for (size_t i = 2; i < datagram.size(); i++) { + datagram[i] = i; + } + stream_->RegisterHttp3DatagramRegistrationVisitor( + &h3_datagram_registration_visitor, /*use_datagram_contexts=*/true); + headers_[":method"] = "CONNECT"; + headers_[":protocol"] = "webtransport"; + headers_["sec-use-datagram-contexts"] = "?1"; + ProcessHeaders(false, headers_); + + // Expect us to send a REGISTER_DATAGRAM_CONTEXT capsule. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(AtLeast(1)); + + stream_->RegisterHttp3DatagramContextId( + context_id, DatagramFormatType::UDP_PAYLOAD, format_additional_data, + &h3_datagram_visitor); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT(h3_datagram_visitor.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), context_id, + std::string(&datagram[2], datagram.size() - 2)})); + // Test move. + ::testing::NiceMock + h3_datagram_registration_visitor2; + stream_->MoveHttp3DatagramRegistration(&h3_datagram_registration_visitor2); + SavingHttp3DatagramVisitor h3_datagram_visitor2; + stream_->MoveHttp3DatagramContextIdRegistration(context_id, + &h3_datagram_visitor2); + EXPECT_TRUE(h3_datagram_visitor2.received_h3_datagrams().empty()); + session_->OnMessageReceived( + absl::string_view(datagram.data(), datagram.size())); + EXPECT_THAT(h3_datagram_visitor2.received_h3_datagrams(), + ElementsAre(SavingHttp3DatagramVisitor::SavedHttp3Datagram{ + stream_->id(), context_id, + std::string(&datagram[2], datagram.size() - 2)})); + // Cleanup. + + // Expect us to send a CLOSE_DATAGRAM_CONTEXT capsule. + EXPECT_CALL(*session_, WritevData(stream_->id(), _, _, _, _, _)) + .Times(AtLeast(1)); + stream_->UnregisterHttp3DatagramContextId(context_id); + stream_->UnregisterHttp3DatagramRegistrationVisitor(); + session_->UnregisterHttp3DatagramFlowId(stream_->id()); +} + +TEST_P(QuicSpdyStreamTest, SendHttp3Datagram) { + if (!UsesHttp3()) { + return; + } + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft04); + absl::optional context_id; + std::string h3_datagram_payload = {1, 2, 3, 4, 5, 6}; + EXPECT_CALL(*connection_, SendMessage(1, _, false)) + .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); + EXPECT_EQ(stream_->SendHttp3Datagram(context_id, h3_datagram_payload), + MESSAGE_STATUS_SUCCESS); +} + +TEST_P(QuicSpdyStreamTest, GetMaxDatagramSize) { + if (!UsesHttp3()) { + return; + } + Initialize(kShouldProcessData); + session_->set_local_http_datagram_support(HttpDatagramSupport::kDraft00And04); + QuicSpdySessionPeer::SetHttpDatagramSupport(session_.get(), + HttpDatagramSupport::kDraft04); + + QuicByteCount size = stream_->GetMaxDatagramSize(absl::nullopt); + QuicByteCount size_with_context = + stream_->GetMaxDatagramSize(/*context_id=*/1); + EXPECT_GT(size, 512u); + EXPECT_EQ(size - 1, size_with_context); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/http/spdy_utils.cc b/gquiche/quic/core/http/spdy_utils.cc index 6586c4f3..0c276c0b 100644 --- a/gquiche/quic/core/http/spdy_utils.cc +++ b/gquiche/quic/core/http/spdy_utils.cc @@ -16,8 +16,7 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" #include "gquiche/spdy/core/spdy_protocol.h" using spdy::SpdyHeaderBlock; @@ -78,7 +77,7 @@ bool SpdyUtils::CopyAndValidateHeaders(const QuicHeaderList& header_list, headers->AppendValueOrAddHeader(name, p.second); } - if (QuicContainsKey(*headers, "content-length") && + if (headers->contains("content-length") && !ExtractContentLengthFromHeaders(content_length, headers)) { return false; } @@ -143,7 +142,7 @@ bool SpdyUtils::PopulateHeaderBlockFromUrl(const std::string url, } (*headers)[":scheme"] = url.substr(0, pos); size_t start = pos + 3; - pos = url.find("/", start); + pos = url.find('/', start); if (pos == std::string::npos) { (*headers)[":authority"] = url.substr(start); (*headers)[":path"] = "/"; @@ -155,7 +154,7 @@ bool SpdyUtils::PopulateHeaderBlockFromUrl(const std::string url, } // static -absl::optional SpdyUtils::ParseDatagramFlowIdHeader( +absl::optional SpdyUtils::ParseDatagramFlowIdHeader( const spdy::SpdyHeaderBlock& headers) { auto flow_id_pair = headers.find("datagram-flow-id"); if (flow_id_pair == headers.end()) { @@ -163,7 +162,7 @@ absl::optional SpdyUtils::ParseDatagramFlowIdHeader( } std::vector flow_id_strings = absl::StrSplit(flow_id_pair->second, ','); - absl::optional first_named_flow_id; + absl::optional first_named_flow_id; for (absl::string_view flow_id_string : flow_id_strings) { std::vector flow_id_components = absl::StrSplit(flow_id_string, ';'); @@ -173,7 +172,7 @@ absl::optional SpdyUtils::ParseDatagramFlowIdHeader( absl::string_view flow_id_value_string = flow_id_components[0]; quiche::QuicheTextUtils::RemoveLeadingAndTrailingWhitespace( &flow_id_value_string); - QuicDatagramFlowId flow_id; + QuicDatagramStreamId flow_id; if (!absl::SimpleAtoi(flow_id_value_string, &flow_id)) { continue; } @@ -191,7 +190,7 @@ absl::optional SpdyUtils::ParseDatagramFlowIdHeader( // static void SpdyUtils::AddDatagramFlowIdHeader(spdy::SpdyHeaderBlock* headers, - QuicDatagramFlowId flow_id) { + QuicDatagramStreamId flow_id) { (*headers)["datagram-flow-id"] = absl::StrCat(flow_id); } diff --git a/gquiche/quic/core/http/spdy_utils.h b/gquiche/quic/core/http/spdy_utils.h index 8b612ac5..671161f4 100644 --- a/gquiche/quic/core/http/spdy_utils.h +++ b/gquiche/quic/core/http/spdy_utils.h @@ -55,12 +55,12 @@ class QUIC_EXPORT_PRIVATE SpdyUtils { // Parses the "datagram-flow-id" header, returns the flow ID on success, or // returns absl::nullopt if the header was not present or failed to parse. - static absl::optional ParseDatagramFlowIdHeader( + static absl::optional ParseDatagramFlowIdHeader( const spdy::SpdyHeaderBlock& headers); // Adds the "datagram-flow-id" header. static void AddDatagramFlowIdHeader(spdy::SpdyHeaderBlock* headers, - QuicDatagramFlowId flow_id); + QuicDatagramStreamId flow_id); }; } // namespace quic diff --git a/gquiche/quic/core/http/spdy_utils_test.cc b/gquiche/quic/core/http/spdy_utils_test.cc index ce4cd4ad..f7a212e9 100644 --- a/gquiche/quic/core/http/spdy_utils_test.cc +++ b/gquiche/quic/core/http/spdy_utils_test.cc @@ -9,7 +9,6 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using spdy::SpdyHeaderBlock; using testing::Pair; @@ -35,7 +34,7 @@ static std::unique_ptr FromList( static void ValidateDatagramFlowId( const std::string& header_value, - absl::optional expected_flow_id) { + absl::optional expected_flow_id) { SpdyHeaderBlock headers; headers["datagram-flow-id"] = header_value; ASSERT_EQ(SpdyUtils::ParseDatagramFlowIdHeader(headers), expected_flow_id); @@ -392,7 +391,7 @@ TEST_F(DatagramFlowIdTest, DatagramFlowId) { SpdyHeaderBlock headers; EXPECT_EQ(SpdyUtils::ParseDatagramFlowIdHeader(headers), absl::nullopt); // Add header and verify it parses. - QuicDatagramFlowId flow_id = 123; + QuicDatagramStreamId flow_id = 123; SpdyUtils::AddDatagramFlowIdHeader(&headers, flow_id); EXPECT_EQ(SpdyUtils::ParseDatagramFlowIdHeader(headers), flow_id); // Test empty header. diff --git a/gquiche/quic/core/http/web_transport_http3.cc b/gquiche/quic/core/http/web_transport_http3.cc index fcb1aa1e..ddc140bd 100644 --- a/gquiche/quic/core/http/web_transport_http3.cc +++ b/gquiche/quic/core/http/web_transport_http3.cc @@ -4,13 +4,17 @@ #include "gquiche/quic/core/http/web_transport_http3.h" +#include #include #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/quic/core/http/capsule.h" #include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/http/quic_spdy_stream.h" #include "gquiche/quic/core/quic_data_reader.h" #include "gquiche/quic/core/quic_data_writer.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" @@ -25,7 +29,9 @@ namespace quic { namespace { class QUIC_NO_EXPORT NoopWebTransportVisitor : public WebTransportVisitor { - void OnSessionReady() override {} + void OnSessionReady(const spdy::SpdyHeaderBlock&) override {} + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} void OnIncomingBidirectionalStreamAvailable() override {} void OnIncomingUnidirectionalStreamAvailable() override {} void OnDatagramReceived(absl::string_view /*datagram*/) override {} @@ -37,16 +43,23 @@ class QUIC_NO_EXPORT NoopWebTransportVisitor : public WebTransportVisitor { WebTransportHttp3::WebTransportHttp3(QuicSpdySession* session, QuicSpdyStream* connect_stream, WebTransportSessionId id, - QuicDatagramFlowId flow_id) + bool attempt_to_use_datagram_contexts) : session_(session), connect_stream_(connect_stream), id_(id), - flow_id_(flow_id), visitor_(std::make_unique()) { QUICHE_DCHECK(session_->SupportsWebTransport()); QUICHE_DCHECK(IsValidWebTransportSessionId(id, session_->version())); QUICHE_DCHECK_EQ(connect_stream_->id(), id); - session_->RegisterHttp3FlowId(flow_id, this); + connect_stream_->RegisterHttp3DatagramRegistrationVisitor( + this, attempt_to_use_datagram_contexts); + if (session_->perspective() == Perspective::IS_CLIENT) { + context_is_known_ = true; + context_currently_registered_ = true; + if (attempt_to_use_datagram_contexts) { + context_id_ = connect_stream_->GetNextDatagramContextId(); + } + } } void WebTransportHttp3::AssociateStream(QuicStreamId stream_id) { @@ -66,7 +79,7 @@ void WebTransportHttp3::AssociateStream(QuicStreamId stream_id) { } } -void WebTransportHttp3::CloseAllAssociatedStreams() { +void WebTransportHttp3::OnConnectStreamClosing() { // Copy the stream list before iterating over it, as calls to ResetStream() // can potentially mutate the |session_| list. std::vector streams(streams_.begin(), streams_.end()); @@ -74,23 +87,138 @@ void WebTransportHttp3::CloseAllAssociatedStreams() { for (QuicStreamId id : streams) { session_->ResetStream(id, QUIC_STREAM_WEBTRANSPORT_SESSION_GONE); } - session_->UnregisterHttp3FlowId(flow_id_); + if (context_currently_registered_) { + context_currently_registered_ = false; + connect_stream_->UnregisterHttp3DatagramContextId(context_id_); + } + connect_stream_->UnregisterHttp3DatagramRegistrationVisitor(); + + MaybeNotifyClose(); +} + +void WebTransportHttp3::CloseSession(WebTransportSessionError error_code, + absl::string_view error_message) { + if (close_sent_) { + QUIC_BUG(WebTransportHttp3 close sent twice) + << "Calling WebTransportHttp3::CloseSession() more than once is not " + "allowed."; + return; + } + close_sent_ = true; + + // There can be a race between us trying to send our close and peer sending + // one. If we received a close, however, we cannot send ours since we already + // closed the stream in response. + if (close_received_) { + QUIC_DLOG(INFO) << "Not sending CLOSE_WEBTRANSPORT_SESSION as we've " + "already sent one from peer."; + return; + } + + error_code_ = error_code; + error_message_ = std::string(error_message); + QuicConnection::ScopedPacketFlusher flusher( + connect_stream_->spdy_session()->connection()); + connect_stream_->WriteCapsule( + Capsule::CloseWebTransportSession(error_code, error_message), + /*fin=*/true); +} + +void WebTransportHttp3::OnCloseReceived(WebTransportSessionError error_code, + absl::string_view error_message) { + if (close_received_) { + QUIC_BUG(WebTransportHttp3 notified of close received twice) + << "WebTransportHttp3::OnCloseReceived() may be only called once."; + } + close_received_ = true; + + // If the peer has sent a close after we sent our own, keep the local error. + if (close_sent_) { + QUIC_DLOG(INFO) << "Ignoring received CLOSE_WEBTRANSPORT_SESSION as we've " + "already sent our own."; + return; + } + + error_code_ = error_code; + error_message_ = std::string(error_message); + connect_stream_->WriteOrBufferBody("", /*fin=*/true); + MaybeNotifyClose(); +} + +void WebTransportHttp3::OnConnectStreamFinReceived() { + // If we already received a CLOSE_WEBTRANSPORT_SESSION capsule, we don't need + // to do anything about receiving a FIN, since we already sent one in + // response. + if (close_received_) { + return; + } + close_received_ = true; + if (close_sent_) { + QUIC_DLOG(INFO) << "Ignoring received FIN as we've already sent our close."; + return; + } + + connect_stream_->WriteOrBufferBody("", /*fin=*/true); + MaybeNotifyClose(); +} + +void WebTransportHttp3::CloseSessionWithFinOnlyForTests() { + QUICHE_DCHECK(!close_sent_); + close_sent_ = true; + if (close_received_) { + return; + } + + connect_stream_->WriteOrBufferBody("", /*fin=*/true); } void WebTransportHttp3::HeadersReceived(const spdy::SpdyHeaderBlock& headers) { if (session_->perspective() == Perspective::IS_CLIENT) { - auto it = headers.find(":status"); - if (it == headers.end() || it->second != "200") { + int status_code; + if (!QuicSpdyStream::ParseHeaderStatusCode(headers, &status_code)) { QUIC_DVLOG(1) << ENDPOINT << "Received WebTransport headers from server without " - "status 200, rejecting."; + "a valid status code, rejecting."; + rejection_reason_ = WebTransportHttp3RejectionReason::kNoStatusCode; return; } + bool valid_status = status_code >= 200 && status_code <= 299; + if (!valid_status) { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server with " + "status code " + << status_code << ", rejecting."; + rejection_reason_ = WebTransportHttp3RejectionReason::kWrongStatusCode; + return; + } + bool should_validate_version = + session_->http_datagram_support() != HttpDatagramSupport::kDraft00 && + session_->ShouldValidateWebTransportVersion(); + if (should_validate_version) { + auto draft_version_it = headers.find("sec-webtransport-http3-draft"); + if (draft_version_it == headers.end()) { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server without " + "a draft version, rejecting."; + rejection_reason_ = + WebTransportHttp3RejectionReason::kMissingDraftVersion; + return; + } + if (draft_version_it->second != "draft02") { + QUIC_DVLOG(1) << ENDPOINT + << "Received WebTransport headers from server with " + "an unknown draft version (" + << draft_version_it->second << "), rejecting."; + rejection_reason_ = + WebTransportHttp3RejectionReason::kUnsupportedDraftVersion; + return; + } + } } QUIC_DVLOG(1) << ENDPOINT << "WebTransport session " << id_ << " ready."; ready_ = true; - visitor_->OnSessionReady(); + visitor_->OnSessionReady(headers); session_->ProcessBufferedWebTransportStreamsForSession(this); } @@ -153,33 +281,118 @@ WebTransportStream* WebTransportHttp3::OpenOutgoingUnidirectionalStream() { } MessageStatus WebTransportHttp3::SendOrQueueDatagram(QuicMemSlice datagram) { - return session_->SendHttp3Datagram( - flow_id_, absl::string_view(datagram.data(), datagram.length())); + return connect_stream_->SendHttp3Datagram( + context_id_, absl::string_view(datagram.data(), datagram.length())); +} + +QuicByteCount WebTransportHttp3::GetMaxDatagramSize() const { + return connect_stream_->GetMaxDatagramSize(context_id_); } void WebTransportHttp3::SetDatagramMaxTimeInQueue( QuicTime::Delta max_time_in_queue) { - session_->SetMaxTimeInQueueForFlowId(flow_id_, max_time_in_queue); + connect_stream_->SetMaxDatagramTimeInQueue(max_time_in_queue); } -void WebTransportHttp3::OnHttp3Datagram(QuicDatagramFlowId flow_id, - absl::string_view payload) { - QUICHE_DCHECK_EQ(flow_id, flow_id_); +void WebTransportHttp3::OnHttp3Datagram( + QuicStreamId stream_id, absl::optional context_id, + absl::string_view payload) { + QUICHE_DCHECK_EQ(stream_id, connect_stream_->id()); + QUICHE_DCHECK(context_id == context_id_); visitor_->OnDatagramReceived(payload); } +void WebTransportHttp3::OnContextReceived( + QuicStreamId stream_id, absl::optional context_id, + DatagramFormatType format_type, absl::string_view format_additional_data) { + if (stream_id != connect_stream_->id()) { + QUIC_BUG(WT3 bad datagram context registration) + << ENDPOINT << "Registered stream ID " << stream_id << ", expected " + << connect_stream_->id(); + return; + } + if (format_type != DatagramFormatType::WEBTRANSPORT) { + QUIC_DLOG(INFO) << ENDPOINT << "Ignoring unexpected datagram format type " + << DatagramFormatTypeToString(format_type); + return; + } + if (!format_additional_data.empty()) { + QUIC_DLOG(ERROR) + << ENDPOINT + << "Received non-empty format additional data for context ID " + << (context_id_.has_value() ? context_id_.value() : 0) + << " on stream ID " << connect_stream_->id(); + session_->ResetStream(connect_stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD); + return; + } + if (!context_is_known_) { + context_is_known_ = true; + context_id_ = context_id; + } + if (context_id != context_id_) { + QUIC_DLOG(INFO) << ENDPOINT << "Ignoring unexpected context ID " + << (context_id.has_value() ? context_id.value() : 0) + << " instead of " + << (context_id_.has_value() ? context_id_.value() : 0) + << " on stream ID " << connect_stream_->id(); + return; + } + if (session_->perspective() == Perspective::IS_SERVER) { + if (context_currently_registered_) { + QUIC_DLOG(ERROR) << ENDPOINT << "Received duplicate context ID " + << (context_id_.has_value() ? context_id_.value() : 0) + << " on stream ID " << connect_stream_->id(); + session_->ResetStream(connect_stream_->id(), QUIC_STREAM_CANCELLED); + return; + } + context_currently_registered_ = true; + connect_stream_->RegisterHttp3DatagramContextId( + context_id_, format_type, format_additional_data, this); + } +} + +void WebTransportHttp3::OnContextClosed( + QuicStreamId stream_id, absl::optional context_id, + ContextCloseCode close_code, absl::string_view close_details) { + if (stream_id != connect_stream_->id()) { + QUIC_BUG(WT3 bad datagram context registration) + << ENDPOINT << "Closed context on stream ID " << stream_id + << ", expected " << connect_stream_->id(); + return; + } + if (context_id != context_id_) { + QUIC_DLOG(INFO) << ENDPOINT << "Ignoring unexpected close of context ID " + << (context_id.has_value() ? context_id.value() : 0) + << " instead of " + << (context_id_.has_value() ? context_id_.value() : 0) + << " on stream ID " << connect_stream_->id(); + return; + } + QUIC_DLOG(INFO) << ENDPOINT + << "Received datagram context close with close code " + << close_code << " close details \"" << close_details + << "\" on stream ID " << connect_stream_->id() + << ", resetting stream"; + session_->ResetStream(connect_stream_->id(), QUIC_BAD_APPLICATION_PAYLOAD); +} + +void WebTransportHttp3::MaybeNotifyClose() { + if (close_notified_) { + return; + } + close_notified_ = true; + visitor_->OnSessionClosed(error_code_, error_message_); +} + WebTransportHttp3UnidirectionalStream::WebTransportHttp3UnidirectionalStream( - PendingStream* pending, - QuicSpdySession* session) - : QuicStream(pending, session, READ_UNIDIRECTIONAL, /*is_static=*/false), + PendingStream* pending, QuicSpdySession* session) + : QuicStream(pending, session, /*is_static=*/false), session_(session), adapter_(session, this, sequencer()), needs_to_send_preamble_(false) {} WebTransportHttp3UnidirectionalStream::WebTransportHttp3UnidirectionalStream( - QuicStreamId id, - QuicSpdySession* session, - WebTransportSessionId session_id) + QuicStreamId id, QuicSpdySession* session, WebTransportSessionId session_id) : QuicStream(id, session, /*is_static=*/false, WRITE_UNIDIRECTIONAL), session_(session), adapter_(session, this, sequencer()), @@ -268,4 +481,65 @@ void WebTransportHttp3UnidirectionalStream::OnClose() { session->OnStreamClosed(id()); } +void WebTransportHttp3UnidirectionalStream::OnStreamReset( + const QuicRstStreamFrame& frame) { + if (adapter_.visitor() != nullptr) { + adapter_.visitor()->OnResetStreamReceived( + Http3ErrorToWebTransportOrDefault(frame.ietf_error_code)); + } + QuicStream::OnStreamReset(frame); +} +bool WebTransportHttp3UnidirectionalStream::OnStopSending( + QuicResetStreamError error) { + if (adapter_.visitor() != nullptr) { + adapter_.visitor()->OnStopSendingReceived( + Http3ErrorToWebTransportOrDefault(error.ietf_application_code())); + } + return QuicStream::OnStopSending(error); +} +void WebTransportHttp3UnidirectionalStream::OnWriteSideInDataRecvdState() { + if (adapter_.visitor() != nullptr) { + adapter_.visitor()->OnWriteSideInDataRecvdState(); + } + + QuicStream::OnWriteSideInDataRecvdState(); +} + +namespace { +constexpr uint64_t kWebTransportMappedErrorCodeFirst = 0x52e4a40fa8db; +constexpr uint64_t kWebTransportMappedErrorCodeLast = 0x52e4a40fa9e2; +constexpr WebTransportStreamError kDefaultWebTransportError = 0; +} // namespace + +absl::optional Http3ErrorToWebTransport( + uint64_t http3_error_code) { + // Ensure the code is within the valid range. + if (http3_error_code < kWebTransportMappedErrorCodeFirst || + http3_error_code > kWebTransportMappedErrorCodeLast) { + return absl::nullopt; + } + // Exclude GREASE codepoints. + if ((http3_error_code - 0x21) % 0x1f == 0) { + return absl::nullopt; + } + + uint64_t shifted = http3_error_code - kWebTransportMappedErrorCodeFirst; + uint64_t result = shifted - shifted / 0x1f; + QUICHE_DCHECK_LE(result, std::numeric_limits::max()); + return result; +} + +WebTransportStreamError Http3ErrorToWebTransportOrDefault( + uint64_t http3_error_code) { + absl::optional result = + Http3ErrorToWebTransport(http3_error_code); + return result.has_value() ? *result : kDefaultWebTransportError; +} + +uint64_t WebTransportErrorToHttp3( + WebTransportStreamError webtransport_error_code) { + return kWebTransportMappedErrorCodeFirst + webtransport_error_code + + webtransport_error_code / 0x1e; +} + } // namespace quic diff --git a/gquiche/quic/core/http/web_transport_http3.h b/gquiche/quic/core/http/web_transport_http3.h index 0695056d..14dc37a8 100644 --- a/gquiche/quic/core/http/web_transport_http3.h +++ b/gquiche/quic/core/http/web_transport_http3.h @@ -7,13 +7,15 @@ #include +#include "absl/base/attributes.h" #include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" #include "gquiche/quic/core/http/quic_spdy_session.h" +#include "gquiche/quic/core/http/web_transport_stream_adapter.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/web_transport_interface.h" -#include "gquiche/quic/core/web_transport_stream_adapter.h" #include "gquiche/spdy/core/spdy_header_block.h" namespace quic { @@ -21,6 +23,14 @@ namespace quic { class QuicSpdySession; class QuicSpdyStream; +enum class WebTransportHttp3RejectionReason { + kNone, + kNoStatusCode, + kWrongStatusCode, + kMissingDraftVersion, + kUnsupportedDraftVersion, +}; + // A session of WebTransport over HTTP/3. The session is owned by // QuicSpdyStream object for the CONNECT stream that established it. // @@ -28,12 +38,12 @@ class QuicSpdyStream; // class QUIC_EXPORT_PRIVATE WebTransportHttp3 : public WebTransportSession, - public QuicSpdySession::Http3DatagramVisitor { + public QuicSpdyStream::Http3DatagramRegistrationVisitor, + public QuicSpdyStream::Http3DatagramVisitor { public: - WebTransportHttp3(QuicSpdySession* session, - QuicSpdyStream* connect_stream, + WebTransportHttp3(QuicSpdySession* session, QuicSpdyStream* connect_stream, WebTransportSessionId id, - QuicDatagramFlowId flow_id); + bool attempt_to_use_datagram_contexts); void HeadersReceived(const spdy::SpdyHeaderBlock& headers); void SetVisitor(std::unique_ptr visitor) { @@ -42,13 +52,27 @@ class QUIC_EXPORT_PRIVATE WebTransportHttp3 WebTransportSessionId id() { return id_; } bool ready() { return ready_; } + absl::optional context_id() const { + return context_id_; + } void AssociateStream(QuicStreamId stream_id); void OnStreamClosed(QuicStreamId stream_id) { streams_.erase(stream_id); } - void CloseAllAssociatedStreams(); + void OnConnectStreamClosing(); size_t NumberOfAssociatedStreams() { return streams_.size(); } + void CloseSession(WebTransportSessionError error_code, + absl::string_view error_message) override; + void OnCloseReceived(WebTransportSessionError error_code, + absl::string_view error_message); + void OnConnectStreamFinReceived(); + + // It is legal for WebTransport to be closed without a + // CLOSE_WEBTRANSPORT_SESSION capsule. We always send a capsule, but we still + // need to ensure we handle this case correctly. + void CloseSessionWithFinOnlyForTests(); + // Return the earliest incoming stream that has been received by the session // but has not been accepted. Returns nullptr if there are no incoming // streams. @@ -61,22 +85,61 @@ class QUIC_EXPORT_PRIVATE WebTransportHttp3 WebTransportStream* OpenOutgoingUnidirectionalStream() override; MessageStatus SendOrQueueDatagram(QuicMemSlice datagram) override; + QuicByteCount GetMaxDatagramSize() const override; void SetDatagramMaxTimeInQueue(QuicTime::Delta max_time_in_queue) override; - void OnHttp3Datagram(QuicDatagramFlowId flow_id, + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::optional context_id, absl::string_view payload) override; + // From QuicSpdyStream::Http3DatagramRegistrationVisitor. + void OnContextReceived(QuicStreamId stream_id, + absl::optional context_id, + DatagramFormatType format_type, + absl::string_view format_additional_data) override; + void OnContextClosed(QuicStreamId stream_id, + absl::optional context_id, + ContextCloseCode close_code, + absl::string_view close_details) override; + + bool close_received() const { return close_received_; } + WebTransportHttp3RejectionReason rejection_reason() const { + return rejection_reason_; + } + private: + // Notifies the visitor that the connection has been closed. Ensures that the + // visitor is only ever called once. + void MaybeNotifyClose(); + QuicSpdySession* const session_; // Unowned. QuicSpdyStream* const connect_stream_; // Unowned. const WebTransportSessionId id_; - const QuicDatagramFlowId flow_id_; + absl::optional context_id_; // |ready_| is set to true when the peer has seen both sets of headers. bool ready_ = false; + // Whether we know which |context_id_| to use. On the client this is always + // true, and on the server it becomes true when we receive a context + // registration capsule. + bool context_is_known_ = false; + // Whether |context_id_| is currently registered with |connect_stream_|. + bool context_currently_registered_ = false; std::unique_ptr visitor_; absl::flat_hash_set streams_; - QuicCircularDeque incoming_bidirectional_streams_; - QuicCircularDeque incoming_unidirectional_streams_; + quiche::QuicheCircularDeque incoming_bidirectional_streams_; + quiche::QuicheCircularDeque incoming_unidirectional_streams_; + + bool close_sent_ = false; + bool close_received_ = false; + bool close_notified_ = false; + + WebTransportHttp3RejectionReason rejection_reason_ = + WebTransportHttp3RejectionReason::kNone; + // Those are set to default values, which are used if the session is not + // closed cleanly using an appropriate capsule. + WebTransportSessionError error_code_ = 0; + std::string error_message_ = ""; }; class QUIC_EXPORT_PRIVATE WebTransportHttp3UnidirectionalStream @@ -97,6 +160,9 @@ class QUIC_EXPORT_PRIVATE WebTransportHttp3UnidirectionalStream void OnDataAvailable() override; void OnCanWriteNewData() override; void OnClose() override; + void OnStreamReset(const QuicRstStreamFrame& frame) override; + bool OnStopSending(QuicResetStreamError error) override; + void OnWriteSideInDataRecvdState() override; WebTransportStream* interface() { return &adapter_; } void SetUnblocked() { sequencer()->SetUnblocked(); } @@ -112,6 +178,20 @@ class QUIC_EXPORT_PRIVATE WebTransportHttp3UnidirectionalStream void MaybeCloseIncompleteStream(); }; +// Remaps HTTP/3 error code into a WebTransport error code. Returns nullopt if +// the provided code is outside of valid range. +QUIC_EXPORT_PRIVATE absl::optional +Http3ErrorToWebTransport(uint64_t http3_error_code); + +// Same as above, but returns default error value (zero) when none could be +// mapped. +QUIC_EXPORT_PRIVATE WebTransportStreamError +Http3ErrorToWebTransportOrDefault(uint64_t http3_error_code); + +// Remaps WebTransport error code into an HTTP/3 error code. +QUIC_EXPORT_PRIVATE uint64_t +WebTransportErrorToHttp3(WebTransportStreamError webtransport_error_code); + } // namespace quic #endif // QUICHE_QUIC_CORE_HTTP_WEB_TRANSPORT_HTTP3_H_ diff --git a/gquiche/quic/core/http/web_transport_http3_test.cc b/gquiche/quic/core/http/web_transport_http3_test.cc new file mode 100644 index 00000000..9e0a8607 --- /dev/null +++ b/gquiche/quic/core/http/web_transport_http3_test.cc @@ -0,0 +1,52 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/http/web_transport_http3.h" + +#include +#include + +#include "absl/types/optional.h" +#include "gquiche/quic/platform/api/quic_test.h" + +namespace quic { +namespace { + +using ::testing::Optional; + +TEST(WebTransportHttp3Test, ErrorCodesToHttp3) { + EXPECT_EQ(0x52e4a40fa8dbu, WebTransportErrorToHttp3(0x00)); + EXPECT_EQ(0x52e4a40fa9e2u, WebTransportErrorToHttp3(0xff)); + + EXPECT_EQ(0x52e4a40fa8f7u, WebTransportErrorToHttp3(0x1c)); + EXPECT_EQ(0x52e4a40fa8f8u, WebTransportErrorToHttp3(0x1d)); + // 0x52e4a40fa8f9 is a GREASE codepoint + EXPECT_EQ(0x52e4a40fa8fau, WebTransportErrorToHttp3(0x1e)); +} + +TEST(WebTransportHttp3Test, ErrorCodesToWebTransport) { + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8db), Optional(0x00)); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa9e2), Optional(0xff)); + + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8f7), Optional(0x1cu)); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8f8), Optional(0x1du)); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8f9), absl::nullopt); + EXPECT_THAT(Http3ErrorToWebTransport(0x52e4a40fa8fa), Optional(0x1eu)); + + EXPECT_EQ(Http3ErrorToWebTransport(0), absl::nullopt); + EXPECT_EQ(Http3ErrorToWebTransport(std::numeric_limits::max()), + absl::nullopt); +} + +TEST(WebTransportHttp3Test, ErrorCodeRoundTrip) { + for (int error = 0; error < 256; error++) { + uint64_t http_error = WebTransportErrorToHttp3(error); + absl::optional mapped_back = + quic::Http3ErrorToWebTransport(http_error); + EXPECT_THAT(mapped_back, Optional(error)); + } +} + +} // namespace +} // namespace quic diff --git a/gquiche/quic/core/web_transport_stream_adapter.cc b/gquiche/quic/core/http/web_transport_stream_adapter.cc similarity index 80% rename from gquiche/quic/core/web_transport_stream_adapter.cc rename to gquiche/quic/core/http/web_transport_stream_adapter.cc index bfc70e53..133f30a4 100644 --- a/gquiche/quic/core/web_transport_stream_adapter.cc +++ b/gquiche/quic/core/http/web_transport_stream_adapter.cc @@ -2,7 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "gquiche/quic/core/web_transport_stream_adapter.h" +#include "gquiche/quic/core/http/web_transport_stream_adapter.h" + +#include "gquiche/quic/core/http/web_transport_http3.h" +#include "gquiche/quic/core/quic_error_codes.h" namespace quic { @@ -42,13 +45,10 @@ bool WebTransportStreamAdapter::Write(absl::string_view data) { return false; } - QuicUniqueBufferPtr buffer = MakeUniqueBuffer( - session_->connection()->helper()->GetStreamSendBufferAllocator(), - data.size()); - memcpy(buffer.get(), data.data(), data.size()); - QuicMemSlice memslice(std::move(buffer), data.size()); + QuicMemSlice memslice(QuicBuffer::Copy( + session_->connection()->helper()->GetStreamSendBufferAllocator(), data)); QuicConsumedData consumed = - stream_->WriteMemSlices(QuicMemSliceSpan(&memslice), /*fin=*/false); + stream_->WriteMemSlices(absl::MakeSpan(&memslice, 1), /*fin=*/false); if (consumed.bytes_consumed == data.size()) { return true; @@ -78,7 +78,7 @@ bool WebTransportStreamAdapter::SendFin() { QuicMemSlice empty; QuicConsumedData consumed = - stream_->WriteMemSlices(QuicMemSliceSpan(&empty), /*fin=*/true); + stream_->WriteMemSlices(absl::MakeSpan(&empty, 1), /*fin=*/true); QUICHE_DCHECK_EQ(consumed.bytes_consumed, 0u); return consumed.fin_consumed; } @@ -113,4 +113,15 @@ void WebTransportStreamAdapter::OnCanWriteNewData() { } } +void WebTransportStreamAdapter::ResetWithUserCode( + WebTransportStreamError error) { + stream_->ResetWriteSide(QuicResetStreamError( + QUIC_STREAM_CANCELLED, WebTransportErrorToHttp3(error))); +} + +void WebTransportStreamAdapter::SendStopSending(WebTransportStreamError error) { + stream_->SendStopSending(QuicResetStreamError( + QUIC_STREAM_CANCELLED, WebTransportErrorToHttp3(error))); +} + } // namespace quic diff --git a/gquiche/quic/core/web_transport_stream_adapter.h b/gquiche/quic/core/http/web_transport_stream_adapter.h similarity index 93% rename from gquiche/quic/core/web_transport_stream_adapter.h rename to gquiche/quic/core/http/web_transport_stream_adapter.h index 46d1c5a1..31a759ad 100644 --- a/gquiche/quic/core/web_transport_stream_adapter.h +++ b/gquiche/quic/core/http/web_transport_stream_adapter.h @@ -8,6 +8,7 @@ #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_stream_sequencer.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/web_transport_interface.h" namespace quic { @@ -34,12 +35,11 @@ class QUIC_EXPORT_PRIVATE WebTransportStreamAdapter } QuicStreamId GetStreamId() const override { return stream_->id(); } - void ResetWithUserCode(QuicRstStreamErrorCode error) override { - stream_->Reset(error); - } + void ResetWithUserCode(WebTransportStreamError error) override; void ResetDueToInternalError() override { stream_->Reset(QUIC_STREAM_INTERNAL_ERROR); } + void SendStopSending(WebTransportStreamError error) override; void MaybeResetDueToStreamObjectGone() override { if (stream_->write_side_closed() && stream_->read_side_closed()) { return; diff --git a/gquiche/quic/core/legacy_quic_stream_id_manager.cc b/gquiche/quic/core/legacy_quic_stream_id_manager.cc index 8564cd5c..fb12ffb6 100644 --- a/gquiche/quic/core/legacy_quic_stream_id_manager.cc +++ b/gquiche/quic/core/legacy_quic_stream_id_manager.cc @@ -7,7 +7,6 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { @@ -125,7 +124,7 @@ bool LegacyQuicStreamIdManager::IsAvailableStream(QuicStreamId id) const { return largest_peer_created_stream_id_ == QuicUtils::GetInvalidStreamId(transport_version_) || id > largest_peer_created_stream_id_ || - QuicContainsKey(available_streams_, id); + available_streams_.contains(id); } bool LegacyQuicStreamIdManager::IsIncomingStream(QuicStreamId id) const { diff --git a/gquiche/quic/core/packet_number_indexed_queue.h b/gquiche/quic/core/packet_number_indexed_queue.h index 6f4b3d18..7ea1a81a 100644 --- a/gquiche/quic/core/packet_number_indexed_queue.h +++ b/gquiche/quic/core/packet_number_indexed_queue.h @@ -5,11 +5,11 @@ #ifndef QUICHE_QUIC_CORE_PACKET_NUMBER_INDEXED_QUEUE_H_ #define QUICHE_QUIC_CORE_PACKET_NUMBER_INDEXED_QUEUE_H_ -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/core/quic_packet_number.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -115,7 +115,7 @@ class QUIC_NO_EXPORT PacketNumberIndexedQueue { return const_cast(const_this->GetEntryWrapper(offset)); } - QuicCircularDeque entries_; + quiche::QuicheCircularDeque entries_; // NOTE(wub): When --quic_bw_sampler_remove_packets_once_per_congestion_event // is enabled, |number_of_present_entries_| only represents number of holes, // which does not include number of acked or lost packets. diff --git a/gquiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc b/gquiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc index 524d26f8..7c9778db 100644 --- a/gquiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc +++ b/gquiche/quic/core/qpack/fuzzer/qpack_decoder_fuzzer.cc @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include + #include #include #include @@ -9,7 +11,7 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/qpack/qpack_decoder.h" -#include "gquiche/quic/platform/api/quic_fuzzed_data_provider.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" @@ -61,7 +63,8 @@ class HeadersHandler : public QpackProgressiveDecoder::HeadersHandlerInterface { QUICHE_CHECK_EQ(1u, result); } - void OnDecodingErrorDetected(absl::string_view /*error_message*/) override { + void OnDecodingErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override { *error_detected_ = true; } @@ -77,7 +80,7 @@ class HeadersHandler : public QpackProgressiveDecoder::HeadersHandlerInterface { // different ways, so the output could not be expected to match the original // input. extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - QuicFuzzedDataProvider provider(data, size); + FuzzedDataProvider provider(data, size); // Maximum 256 byte dynamic table. Such a small size helps test draining // entries and eviction. diff --git a/gquiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc b/gquiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc new file mode 100644 index 00000000..4274ec8a --- /dev/null +++ b/gquiche/quic/core/qpack/fuzzer/qpack_decoder_stream_receiver_fuzzer.cc @@ -0,0 +1,62 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "gquiche/quic/core/qpack/qpack_decoder_stream_receiver.h" +#include "gquiche/quic/core/quic_error_codes.h" +#include "gquiche/quic/core/quic_stream.h" + +namespace quic { +namespace test { +namespace { + +// A QpackDecoderStreamReceiver::Delegate implementation that ignores all +// decoded instructions but keeps track of whether an error has been detected. +class NoOpDelegate : public QpackDecoderStreamReceiver::Delegate { + public: + NoOpDelegate() : error_detected_(false) {} + ~NoOpDelegate() override = default; + + void OnInsertCountIncrement(uint64_t /*increment*/) override {} + void OnHeaderAcknowledgement(QuicStreamId /*stream_id*/) override {} + void OnStreamCancellation(QuicStreamId /*stream_id*/) override {} + void OnErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override { + error_detected_ = true; + } + + bool error_detected() const { return error_detected_; } + + private: + bool error_detected_; +}; + +} // namespace + +// This fuzzer exercises QpackDecoderStreamReceiver. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + NoOpDelegate delegate; + QpackDecoderStreamReceiver receiver(&delegate); + + FuzzedDataProvider provider(data, size); + + while (!delegate.error_detected() && provider.remaining_bytes() != 0) { + // Process up to 64 kB fragments at a time. Too small upper bound might not + // provide enough coverage, too large might make fuzzing too inefficient. + size_t fragment_size = provider.ConsumeIntegralInRange( + 0, std::numeric_limits::max()); + receiver.Decode(provider.ConsumeRandomLengthString(fragment_size)); + } + + return 0; +} + +} // namespace test +} // namespace quic diff --git a/gquiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc b/gquiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc new file mode 100644 index 00000000..fb459a69 --- /dev/null +++ b/gquiche/quic/core/qpack/fuzzer/qpack_decoder_stream_sender_fuzzer.cc @@ -0,0 +1,54 @@ +// Copyright 2018 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include + +#include +#include + +#include "gquiche/quic/core/qpack/qpack_decoder_stream_sender.h" +#include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" + +namespace quic { +namespace test { + +// This fuzzer exercises QpackDecoderStreamSender. +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + NoopQpackStreamSenderDelegate delegate; + QpackDecoderStreamSender sender; + sender.set_qpack_stream_sender_delegate(&delegate); + + FuzzedDataProvider provider(data, size); + + while (provider.remaining_bytes() != 0) { + switch (provider.ConsumeIntegral() % 4) { + case 0: { + uint64_t increment = provider.ConsumeIntegral(); + sender.SendInsertCountIncrement(increment); + break; + } + case 1: { + QuicStreamId stream_id = provider.ConsumeIntegral(); + sender.SendHeaderAcknowledgement(stream_id); + break; + } + case 2: { + QuicStreamId stream_id = provider.ConsumeIntegral(); + sender.SendStreamCancellation(stream_id); + break; + } + case 3: { + sender.Flush(); + break; + } + } + } + + sender.Flush(); + return 0; +} + +} // namespace test +} // namespace quic diff --git a/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc b/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc index f96641cc..e1f06416 100644 --- a/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc +++ b/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_receiver_fuzzer.cc @@ -2,13 +2,13 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -#include "gquiche/quic/core/qpack/qpack_encoder_stream_receiver.h" +#include #include #include #include "absl/strings/string_view.h" -#include "gquiche/quic/platform/api/quic_fuzzed_data_provider.h" +#include "gquiche/quic/core/qpack/qpack_encoder_stream_receiver.h" #include "gquiche/quic/platform/api/quic_logging.h" namespace quic { @@ -51,13 +51,13 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { NoOpDelegate delegate; QpackEncoderStreamReceiver receiver(&delegate); - QuicFuzzedDataProvider provider(data, size); + FuzzedDataProvider provider(data, size); while (!delegate.error_detected() && provider.remaining_bytes() != 0) { // Process up to 64 kB fragments at a time. Too small upper bound might not // provide enough coverage, too large might make fuzzing too inefficient. size_t fragment_size = provider.ConsumeIntegralInRange( - 1, std::numeric_limits::max()); + 0, std::numeric_limits::max()); receiver.Decode(provider.ConsumeRandomLengthString(fragment_size)); } diff --git a/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc b/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc index 4373e7f5..57675172 100644 --- a/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc +++ b/gquiche/quic/core/qpack/fuzzer/qpack_encoder_stream_sender_fuzzer.cc @@ -2,13 +2,14 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include + #include #include #include #include #include "gquiche/quic/core/qpack/qpack_encoder_stream_sender.h" -#include "gquiche/quic/platform/api/quic_fuzzed_data_provider.h" #include "gquiche/quic/test_tools/qpack/qpack_encoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" @@ -16,20 +17,17 @@ namespace quic { namespace test { // This fuzzer exercises QpackEncoderStreamSender. -// TODO(bnc): Encoded data could be fed into QpackEncoderStreamReceiver and -// decoded instructions directly compared to input. Figure out how to get gMock -// enabled for cc_fuzz_target target types. extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { NoopQpackStreamSenderDelegate delegate; QpackEncoderStreamSender sender; sender.set_qpack_stream_sender_delegate(&delegate); - QuicFuzzedDataProvider provider(data, size); + FuzzedDataProvider provider(data, size); // Limit string literal length to 2 kB for efficiency. const uint16_t kMaxStringLength = 2048; while (provider.remaining_bytes() != 0) { - switch (provider.ConsumeIntegral() % 4) { + switch (provider.ConsumeIntegral() % 5) { case 0: { bool is_static = provider.ConsumeBool(); uint64_t name_index = provider.ConsumeIntegral(); @@ -60,9 +58,14 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { sender.SendSetDynamicTableCapacity(capacity); break; } + case 4: { + sender.Flush(); + break; + } } } + sender.Flush(); return 0; } diff --git a/gquiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc b/gquiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc index fc86e6f9..42cce225 100644 --- a/gquiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc +++ b/gquiche/quic/core/qpack/fuzzer/qpack_round_trip_fuzzer.cc @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include + #include #include #include @@ -15,15 +17,28 @@ #include "gquiche/quic/core/qpack/qpack_encoder.h" #include "gquiche/quic/core/qpack/qpack_stream_sender_delegate.h" #include "gquiche/quic/core/qpack/value_splitting_header_list.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_error_codes.h" -#include "gquiche/quic/platform/api/quic_fuzzed_data_provider.h" #include "gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_encoder_peer.h" +#include "gquiche/common/quiche_circular_deque.h" #include "gquiche/spdy/core/spdy_header_block.h" namespace quic { namespace test { +namespace { + +// Find the first occurrence of invalid characters NUL, LF, CR in |*value| and +// remove that and the remaining of the string. +void TruncateValueOnInvalidChars(std::string* value) { + for (auto it = value->begin(); it != value->end(); ++it) { + if (*it == '\0' || *it == '\n' || *it == '\r') { + value->erase(it, value->end()); + return; + } + } +} + +} // anonymous namespace // Class to hold QpackEncoder and its DecoderStreamErrorDelegate. class EncodingEndpoint { @@ -99,8 +114,7 @@ class DelayedHeaderBlockTransmitter { virtual void OnHeaderBlockEnd(QuicStreamId stream_id) = 0; }; - DelayedHeaderBlockTransmitter(Visitor* visitor, - QuicFuzzedDataProvider* provider) + DelayedHeaderBlockTransmitter(Visitor* visitor, FuzzedDataProvider* provider) : visitor_(visitor), provider_(provider) {} ~DelayedHeaderBlockTransmitter() { QUICHE_CHECK(header_blocks_.empty()); } @@ -227,7 +241,7 @@ class DelayedHeaderBlockTransmitter { }; Visitor* const visitor_; - QuicFuzzedDataProvider* const provider_; + FuzzedDataProvider* const provider_; std::map> header_blocks_; }; @@ -278,8 +292,10 @@ class VerifyingDecoder : public QpackDecodedHeadersAccumulator::Visitor { visitor_->OnHeaderBlockDecoded(stream_id_); } - void OnHeaderDecodingError(absl::string_view error_message) override { - QUICHE_CHECK(false) << error_message; + void OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) override { + QUICHE_CHECK(false) << QuicErrorCodeToString(error_code) << " " + << error_message; } void Decode(absl::string_view data) { accumulator_.Decode(data); } @@ -406,7 +422,7 @@ class DecodingEndpoint : public DelayedHeaderBlockTransmitter::Visitor, class DelayedStreamDataTransmitter : public QpackStreamSenderDelegate { public: DelayedStreamDataTransmitter(QpackStreamReceiver* receiver, - QuicFuzzedDataProvider* provider) + FuzzedDataProvider* provider) : receiver_(receiver), provider_(provider) {} ~DelayedStreamDataTransmitter() { QUICHE_CHECK(stream_data.empty()); } @@ -436,12 +452,12 @@ class DelayedStreamDataTransmitter : public QpackStreamSenderDelegate { private: QpackStreamReceiver* const receiver_; - QuicFuzzedDataProvider* const provider_; - QuicCircularDeque stream_data; + FuzzedDataProvider* const provider_; + quiche::QuicheCircularDeque stream_data; }; // Generate header list using fuzzer data. -spdy::Http2HeaderBlock GenerateHeaderList(QuicFuzzedDataProvider* provider) { +spdy::Http2HeaderBlock GenerateHeaderList(FuzzedDataProvider* provider) { spdy::Http2HeaderBlock header_list; uint8_t header_count = provider->ConsumeIntegral(); for (uint8_t header_index = 0; header_index < header_count; ++header_index) { @@ -509,6 +525,7 @@ spdy::Http2HeaderBlock GenerateHeaderList(QuicFuzzedDataProvider* provider) { // Header name not in the static table, fuzzed header value. name = "foo"; value = provider->ConsumeRandomLengthString(128); + TruncateValueOnInvalidChars(&value); break; case 11: // Another header name not in the static table, empty header value. @@ -525,11 +542,13 @@ spdy::Http2HeaderBlock GenerateHeaderList(QuicFuzzedDataProvider* provider) { // Another header name not in the static table, fuzzed header value. name = "bar"; value = provider->ConsumeRandomLengthString(128); + TruncateValueOnInvalidChars(&value); break; default: // Fuzzed header name and header value. name = provider->ConsumeRandomLengthString(128); value = provider->ConsumeRandomLengthString(128); + TruncateValueOnInvalidChars(&value); } header_list.AppendValueOrAddHeader(name, value); @@ -562,7 +581,7 @@ QuicHeaderList SplitHeaderList(const spdy::Http2HeaderBlock& header_list) { // encoding then decoding is expected to result in the original header list, and // this fuzzer checks for that. extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { - QuicFuzzedDataProvider provider(data, size); + FuzzedDataProvider provider(data, size); // Maximum 256 byte dynamic table. Such a small size helps test draining // entries and eviction. diff --git a/gquiche/quic/core/qpack/qpack_blocking_manager.h b/gquiche/quic/core/qpack/qpack_blocking_manager.h index 38f4ae84..6c0c5520 100644 --- a/gquiche/quic/core/qpack/qpack_blocking_manager.h +++ b/gquiche/quic/core/qpack/qpack_blocking_manager.h @@ -75,8 +75,8 @@ class QUIC_EXPORT_PRIVATE QpackBlockingManager { // A stream typically has only one header block, except for the rare cases of // 1xx responses, trailers, or push promises. Even if there are multiple // header blocks sent on a single stream, they might not be blocked at the - // same time. Use std::list instead of QuicCircularDeque because it has lower - // memory footprint when holding few elements. + // same time. Use std::list instead of quiche::QuicheCircularDeque because it + // has lower memory footprint when holding few elements. using HeaderBlocksForStream = std::list; using HeaderBlocks = absl::flat_hash_map; diff --git a/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc b/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc index 06286416..9c11bb28 100644 --- a/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc +++ b/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.cc @@ -67,13 +67,13 @@ void QpackDecodedHeadersAccumulator::OnDecodingCompleted() { } void QpackDecodedHeadersAccumulator::OnDecodingErrorDetected( - absl::string_view error_message) { + QuicErrorCode error_code, absl::string_view error_message) { QUICHE_DCHECK(!error_detected_); QUICHE_DCHECK(!headers_decoded_); error_detected_ = true; // Might destroy |this|. - visitor_->OnHeaderDecodingError(error_message); + visitor_->OnHeaderDecodingError(error_code, error_message); } void QpackDecodedHeadersAccumulator::Decode(absl::string_view data) { diff --git a/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.h b/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.h index 12067bdc..bff3213a 100644 --- a/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.h +++ b/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator.h @@ -11,6 +11,7 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/http/quic_header_list.h" #include "gquiche/quic/core/qpack/qpack_progressive_decoder.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_export.h" @@ -45,7 +46,8 @@ class QUIC_EXPORT_PRIVATE QpackDecodedHeadersAccumulator bool header_list_size_limit_exceeded) = 0; // Called when an error has occurred. - virtual void OnHeaderDecodingError(absl::string_view error_message) = 0; + virtual void OnHeaderDecodingError(QuicErrorCode error_code, + absl::string_view error_message) = 0; }; QpackDecodedHeadersAccumulator(QuicStreamId id, @@ -59,7 +61,8 @@ class QUIC_EXPORT_PRIVATE QpackDecodedHeadersAccumulator void OnHeaderDecoded(absl::string_view name, absl::string_view value) override; void OnDecodingCompleted() override; - void OnDecodingErrorDetected(absl::string_view error_message) override; + void OnDecodingErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) override; // Decode payload data. // Must not be called if an error has been detected. diff --git a/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc b/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc index 6ec6392b..7fb228e2 100644 --- a/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc +++ b/gquiche/quic/core/qpack/qpack_decoded_headers_accumulator_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using ::testing::_; using ::testing::ElementsAre; @@ -47,9 +46,8 @@ class MockVisitor : public QpackDecodedHeadersAccumulator::Visitor { OnHeadersDecoded, (QuicHeaderList headers, bool header_list_size_limit_exceeded), (override)); - MOCK_METHOD(void, - OnHeaderDecodingError, - (absl::string_view error_message), + MOCK_METHOD(void, OnHeaderDecodingError, + (QuicErrorCode error_code, absl::string_view error_message), (override)); }; @@ -79,7 +77,8 @@ class QpackDecodedHeadersAccumulatorTest : public QuicTest { // HEADERS frame payload must have a complete Header Block Prefix. TEST_F(QpackDecodedHeadersAccumulatorTest, EmptyPayload) { EXPECT_CALL(visitor_, - OnHeaderDecodingError(Eq("Incomplete header data prefix."))); + OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header data prefix."))); accumulator_.EndHeaderBlock(); } @@ -88,7 +87,8 @@ TEST_F(QpackDecodedHeadersAccumulatorTest, TruncatedHeaderBlockPrefix) { accumulator_.Decode(absl::HexStringToBytes("00")); EXPECT_CALL(visitor_, - OnHeaderDecodingError(Eq("Incomplete header data prefix."))); + OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header data prefix."))); accumulator_.EndHeaderBlock(); } @@ -111,14 +111,16 @@ TEST_F(QpackDecodedHeadersAccumulatorTest, EmptyHeaderList) { TEST_F(QpackDecodedHeadersAccumulatorTest, TruncatedPayload) { accumulator_.Decode(absl::HexStringToBytes("00002366")); - EXPECT_CALL(visitor_, OnHeaderDecodingError(Eq("Incomplete header block."))); + EXPECT_CALL(visitor_, OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header block."))); accumulator_.EndHeaderBlock(); } // This payload is invalid because it refers to a non-existing static entry. TEST_F(QpackDecodedHeadersAccumulatorTest, InvalidPayload) { EXPECT_CALL(visitor_, - OnHeaderDecodingError(Eq("Static table entry not found."))); + OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Static table entry not found."))); accumulator_.Decode(absl::HexStringToBytes("0000ff23ff24")); } @@ -242,7 +244,8 @@ TEST_F(QpackDecodedHeadersAccumulatorTest, qpack_decoder_.OnSetDynamicTableCapacity(kMaxDynamicTableCapacity); // Adding dynamic table entry unblocks decoding. Error is detected. - EXPECT_CALL(visitor_, OnHeaderDecodingError(Eq("Invalid relative index."))); + EXPECT_CALL(visitor_, OnHeaderDecodingError(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); qpack_decoder_.OnInsertWithoutNameReference("foo", "bar"); } diff --git a/gquiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc b/gquiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc index 1d02ed67..b5bc09c6 100644 --- a/gquiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc +++ b/gquiche/quic/core/qpack/qpack_decoder_stream_receiver_test.cc @@ -7,7 +7,6 @@ #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using testing::Eq; using testing::StrictMock; diff --git a/gquiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc b/gquiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc index 269c7ca0..f88a2ef9 100644 --- a/gquiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc +++ b/gquiche/quic/core/qpack/qpack_decoder_stream_sender_test.cc @@ -7,7 +7,6 @@ #include "absl/strings/escaping.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using ::testing::Eq; using ::testing::StrictMock; diff --git a/gquiche/quic/core/qpack/qpack_decoder_test.cc b/gquiche/quic/core/qpack/qpack_decoder_test.cc index d780fa42..4e4cbe22 100644 --- a/gquiche/quic/core/qpack/qpack_decoder_test.cc +++ b/gquiche/quic/core/qpack/qpack_decoder_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/spdy/core/spdy_header_block.h" using ::testing::_; @@ -50,8 +49,9 @@ class QpackDecoderTest : public QuicTestWithParam { void SetUp() override { // Destroy QpackProgressiveDecoder on error to test that it does not crash. // See https://crbug.com/1025209. - ON_CALL(handler_, OnDecodingErrorDetected(_)) - .WillByDefault(Invoke([this](absl::string_view /* error_message */) { + ON_CALL(handler_, OnDecodingErrorDetected(_, _)) + .WillByDefault(Invoke([this](QuicErrorCode /* error_code */, + absl::string_view /* error_message */) { progressive_decoder_.reset(); })); } @@ -115,7 +115,8 @@ INSTANTIATE_TEST_SUITE_P(All, TEST_P(QpackDecoderTest, NoPrefix) { EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Incomplete header data prefix."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header data prefix."))); // Header Data Prefix is at least two bytes long. DecodeHeaderBlock(absl::HexStringToBytes("00")); @@ -127,7 +128,8 @@ TEST_P(QpackDecoderTest, InvalidPrefix) { StartDecoding(); EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Encoded integer too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Encoded integer too large."))); // Encoded Required Insert Count in Header Data Prefix is too large. DecodeData(absl::HexStringToBytes("ffffffffffffffffffffffffffff")); @@ -189,7 +191,8 @@ TEST_P(QpackDecoderTest, MultipleLiteralEntries) { // Name Length value is too large for varint decoder to decode. TEST_P(QpackDecoderTest, NameLenTooLargeForVarintDecoder) { EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Encoded integer too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Encoded integer too large."))); DecodeHeaderBlock(absl::HexStringToBytes("000027ffffffffffffffffffff")); } @@ -197,7 +200,8 @@ TEST_P(QpackDecoderTest, NameLenTooLargeForVarintDecoder) { // Name Length value can be decoded by varint decoder but exceeds 1 MB limit. TEST_P(QpackDecoderTest, NameLenExceedsLimit) { EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("String literal too long."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("String literal too long."))); DecodeHeaderBlock(absl::HexStringToBytes("000027ffff7f")); } @@ -205,7 +209,8 @@ TEST_P(QpackDecoderTest, NameLenExceedsLimit) { // Value Length value is too large for varint decoder to decode. TEST_P(QpackDecoderTest, ValueLenTooLargeForVarintDecoder) { EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Encoded integer too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Encoded integer too large."))); DecodeHeaderBlock( absl::HexStringToBytes("000023666f6f7fffffffffffffffffffff")); @@ -214,14 +219,24 @@ TEST_P(QpackDecoderTest, ValueLenTooLargeForVarintDecoder) { // Value Length value can be decoded by varint decoder but exceeds 1 MB limit. TEST_P(QpackDecoderTest, ValueLenExceedsLimit) { EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("String literal too long."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("String literal too long."))); DecodeHeaderBlock(absl::HexStringToBytes("000023666f6f7fffff7f")); } +TEST_P(QpackDecoderTest, LineFeedInValue) { + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_INVALID_CHARACTER_IN_FIELD_VALUE, + "Invalid character in field value.")); + + DecodeHeaderBlock(absl::HexStringToBytes("000023666f6f0462610a72")); +} + TEST_P(QpackDecoderTest, IncompleteHeaderBlock) { EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Incomplete header block."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Incomplete header block."))); DecodeHeaderBlock(absl::HexStringToBytes("00002366")); } @@ -252,8 +267,9 @@ TEST_P(QpackDecoderTest, AlternatingHuffmanNonHuffman) { } TEST_P(QpackDecoderTest, HuffmanNameDoesNotHaveEOSPrefix) { - EXPECT_CALL(handler_, OnDecodingErrorDetected(absl::string_view( - "Error in Huffman-encoded string."))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); // 'y' ends in 0b0 on the most significant bit of the last byte. // The remaining 7 bits must be a prefix of EOS, which is all 1s. @@ -262,8 +278,9 @@ TEST_P(QpackDecoderTest, HuffmanNameDoesNotHaveEOSPrefix) { } TEST_P(QpackDecoderTest, HuffmanValueDoesNotHaveEOSPrefix) { - EXPECT_CALL(handler_, OnDecodingErrorDetected(absl::string_view( - "Error in Huffman-encoded string."))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); // 'e' ends in 0b101, taking up the 3 most significant bits of the last byte. // The remaining 5 bits must be a prefix of EOS, which is all 1s. @@ -272,8 +289,9 @@ TEST_P(QpackDecoderTest, HuffmanValueDoesNotHaveEOSPrefix) { } TEST_P(QpackDecoderTest, HuffmanNameEOSPrefixTooLong) { - EXPECT_CALL(handler_, OnDecodingErrorDetected(absl::string_view( - "Error in Huffman-encoded string."))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); // The trailing EOS prefix must be at most 7 bits long. Appending one octet // with value 0xff is invalid, even though 0b111111111111111 (15 bits) is a @@ -283,8 +301,9 @@ TEST_P(QpackDecoderTest, HuffmanNameEOSPrefixTooLong) { } TEST_P(QpackDecoderTest, HuffmanValueEOSPrefixTooLong) { - EXPECT_CALL(handler_, OnDecodingErrorDetected(absl::string_view( - "Error in Huffman-encoded string."))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error in Huffman-encoded string."))); // The trailing EOS prefix must be at most 7 bits long. Appending one octet // with value 0xff is invalid, even though 0b1111111111111 (13 bits) is a @@ -322,7 +341,8 @@ TEST_P(QpackDecoderTest, TooHighStaticTableIndex) { // Addressing entry 99 should trigger an error. EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Static table entry not found."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Static table entry not found."))); DecodeHeaderBlock(absl::HexStringToBytes("0000ff23ff24")); } @@ -431,6 +451,7 @@ TEST_P(QpackDecoderTest, DecreasingDynamicTableCapacityEvictsEntries) { DecodeEncoderStreamData(absl::HexStringToBytes("3f01")); EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Dynamic table entry already evicted."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -498,7 +519,8 @@ TEST_P(QpackDecoderTest, EncoderStreamErrorTooLargeInteger) { } TEST_P(QpackDecoderTest, InvalidDynamicEntryWhenBaseIsZero) { - EXPECT_CALL(handler_, OnDecodingErrorDetected(Eq("Invalid relative index."))); + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); // Set dynamic table capacity to 1024. DecodeEncoderStreamData(absl::HexStringToBytes("3fe107")); @@ -513,7 +535,8 @@ TEST_P(QpackDecoderTest, InvalidDynamicEntryWhenBaseIsZero) { } TEST_P(QpackDecoderTest, InvalidNegativeBase) { - EXPECT_CALL(handler_, OnDecodingErrorDetected(Eq("Error calculating Base."))); + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Error calculating Base."))); // Required Insert Count 1, Delta Base 1 with sign bit set, Base would // be 1 - 1 - 1 = -1, but it is not allowed to be negative. @@ -526,7 +549,8 @@ TEST_P(QpackDecoderTest, InvalidDynamicEntryByRelativeIndex) { // Add literal entry with name "foo" and value "bar". DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); - EXPECT_CALL(handler_, OnDecodingErrorDetected(Eq("Invalid relative index."))); + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); DecodeHeaderBlock(absl::HexStringToBytes( "0200" // Required Insert Count 1 and Delta Base 0. @@ -534,7 +558,8 @@ TEST_P(QpackDecoderTest, InvalidDynamicEntryByRelativeIndex) { "81")); // Indexed Header Field instruction addressing relative index 1. // This is absolute index -1, which is invalid. - EXPECT_CALL(handler_, OnDecodingErrorDetected(Eq("Invalid relative index."))); + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); DecodeHeaderBlock(absl::HexStringToBytes( "0200" // Required Insert Count 1 and Delta Base 0. @@ -555,6 +580,7 @@ TEST_P(QpackDecoderTest, EvictedDynamicTableEntry) { DecodeEncoderStreamData(absl::HexStringToBytes("00000000")); EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Dynamic table entry already evicted."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -564,6 +590,7 @@ TEST_P(QpackDecoderTest, EvictedDynamicTableEntry) { // This is absolute index 1. Such entry does not exist. EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Dynamic table entry already evicted."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -574,6 +601,7 @@ TEST_P(QpackDecoderTest, EvictedDynamicTableEntry) { // entry does not exist. EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Dynamic table entry already evicted."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -584,6 +612,7 @@ TEST_P(QpackDecoderTest, EvictedDynamicTableEntry) { // does not exist. EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Dynamic table entry already evicted."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -615,6 +644,7 @@ TEST_P(QpackDecoderTest, InvalidEncodedRequiredInsertCount) { // Required Insert Count is decoded modulo 2 * MaxEntries, that is, modulo 64. // A value of 1 cannot be encoded as 65 even though it has the same remainder. EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Error decoding Required Insert Count."))); DecodeHeaderBlock(absl::HexStringToBytes("4100")); } @@ -623,6 +653,7 @@ TEST_P(QpackDecoderTest, InvalidEncodedRequiredInsertCount) { // after a Header Block Prefix with an invalid Encoded Required Insert Count. TEST_P(QpackDecoderTest, DataAfterInvalidEncodedRequiredInsertCount) { EXPECT_CALL(handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Error decoding Required Insert Count."))); // Header Block Prefix followed by some extra data. DecodeHeaderBlock(absl::HexStringToBytes("410000")); @@ -667,7 +698,8 @@ TEST_P(QpackDecoderTest, NonZeroRequiredInsertCountButNoDynamicEntries) { EXPECT_CALL(handler_, OnHeaderDecoded(Eq(":method"), Eq("GET"))); EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Required Insert Count too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); DecodeHeaderBlock(absl::HexStringToBytes( "0200" // Required Insert Count is 1. @@ -683,6 +715,7 @@ TEST_P(QpackDecoderTest, AddressEntryNotAllowedByRequiredInsertCount) { EXPECT_CALL( handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Absolute Index must be smaller than Required Insert Count."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -695,6 +728,7 @@ TEST_P(QpackDecoderTest, AddressEntryNotAllowedByRequiredInsertCount) { EXPECT_CALL( handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Absolute Index must be smaller than Required Insert Count."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -708,6 +742,7 @@ TEST_P(QpackDecoderTest, AddressEntryNotAllowedByRequiredInsertCount) { EXPECT_CALL( handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Absolute Index must be smaller than Required Insert Count."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -721,6 +756,7 @@ TEST_P(QpackDecoderTest, AddressEntryNotAllowedByRequiredInsertCount) { EXPECT_CALL( handler_, OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, Eq("Absolute Index must be smaller than Required Insert Count."))); DecodeHeaderBlock(absl::HexStringToBytes( @@ -744,7 +780,8 @@ TEST_P(QpackDecoderTest, PromisedRequiredInsertCountLargerThanActual) { EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Required Insert Count too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); DecodeHeaderBlock(absl::HexStringToBytes( "0300" // Required Insert Count 2 and Delta Base 0. @@ -756,7 +793,8 @@ TEST_P(QpackDecoderTest, PromisedRequiredInsertCountLargerThanActual) { EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq(""))); EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Required Insert Count too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); DecodeHeaderBlock(absl::HexStringToBytes( "0300" // Required Insert Count 2 and Delta Base 0. @@ -768,7 +806,8 @@ TEST_P(QpackDecoderTest, PromisedRequiredInsertCountLargerThanActual) { EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Required Insert Count too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); DecodeHeaderBlock(absl::HexStringToBytes( "0481" // Required Insert Count 3 and Delta Base 1 with sign bit set. @@ -780,7 +819,8 @@ TEST_P(QpackDecoderTest, PromisedRequiredInsertCountLargerThanActual) { EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq(""))); EXPECT_CALL(handler_, - OnDecodingErrorDetected(Eq("Required Insert Count too large."))); + OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Required Insert Count too large."))); DecodeHeaderBlock(absl::HexStringToBytes( "0481" // Required Insert Count 3 and Delta Base 1 with sign bit set. @@ -865,7 +905,8 @@ TEST_P(QpackDecoderTest, // Count of the header block. |handler_| methods are called immediately for // the already consumed part of the header block. EXPECT_CALL(handler_, OnHeaderDecoded(Eq("foo"), Eq("bar"))); - EXPECT_CALL(handler_, OnDecodingErrorDetected(Eq("Invalid relative index."))); + EXPECT_CALL(handler_, OnDecodingErrorDetected(QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Invalid relative index."))); DecodeEncoderStreamData(absl::HexStringToBytes("6294e703626172")); } @@ -906,8 +947,10 @@ TEST_P(QpackDecoderTest, TooManyBlockedStreams) { auto progressive_decoder1 = CreateProgressiveDecoder(/* stream_id = */ 1); progressive_decoder1->Decode(data); - EXPECT_CALL(handler_, OnDecodingErrorDetected(Eq( - "Limit on number of blocked streams exceeded."))); + EXPECT_CALL(handler_, + OnDecodingErrorDetected( + QUIC_QPACK_DECOMPRESSION_FAILED, + Eq("Limit on number of blocked streams exceeded."))); auto progressive_decoder2 = CreateProgressiveDecoder(/* stream_id = */ 2); progressive_decoder2->Decode(data); diff --git a/gquiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc b/gquiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc index 7699fb45..3398baa6 100644 --- a/gquiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc +++ b/gquiche/quic/core/qpack/qpack_encoder_stream_receiver_test.cc @@ -7,7 +7,6 @@ #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using testing::Eq; using testing::StrictMock; diff --git a/gquiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc b/gquiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc index c0b20288..8554092b 100644 --- a/gquiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc +++ b/gquiche/quic/core/qpack/qpack_encoder_stream_sender_test.cc @@ -7,7 +7,6 @@ #include "absl/strings/escaping.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using ::testing::Eq; using ::testing::StrictMock; diff --git a/gquiche/quic/core/qpack/qpack_encoder_test.cc b/gquiche/quic/core/qpack/qpack_encoder_test.cc index 336b872b..da64408b 100644 --- a/gquiche/quic/core/qpack/qpack_encoder_test.cc +++ b/gquiche/quic/core/qpack/qpack_encoder_test.cc @@ -15,7 +15,6 @@ #include "gquiche/quic/test_tools/qpack/qpack_encoder_peer.h" #include "gquiche/quic/test_tools/qpack/qpack_encoder_test_utils.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using ::testing::_; using ::testing::Eq; diff --git a/gquiche/quic/core/qpack/qpack_header_table.cc b/gquiche/quic/core/qpack/qpack_header_table.cc index ce77509b..2d09b4d9 100644 --- a/gquiche/quic/core/qpack/qpack_header_table.cc +++ b/gquiche/quic/core/qpack/qpack_header_table.cc @@ -10,87 +10,14 @@ namespace quic { -QpackHeaderTableBase::QpackHeaderTableBase() - : dynamic_table_size_(0), - dynamic_table_capacity_(0), - maximum_dynamic_table_capacity_(0), - max_entries_(0), - dropped_entry_count_(0), - dynamic_table_entry_referenced_(false) {} - -bool QpackHeaderTableBase::EntryFitsDynamicTableCapacity( - absl::string_view name, - absl::string_view value) const { - return QpackEntry::Size(name, value) <= dynamic_table_capacity_; -} - -uint64_t QpackHeaderTableBase::InsertEntry(absl::string_view name, - absl::string_view value) { - QUICHE_DCHECK(EntryFitsDynamicTableCapacity(name, value)); - - const uint64_t index = dropped_entry_count_ + dynamic_entries_.size(); - - // Copy name and value before modifying the container, because evicting - // entries or even inserting a new one might invalidate |name| or |value| if - // they point to an entry. - QpackEntry new_entry((std::string(name)), (std::string(value))); - const size_t entry_size = new_entry.Size(); - - EvictDownToCapacity(dynamic_table_capacity_ - entry_size); - - dynamic_table_size_ += entry_size; - dynamic_entries_.push_back(std::move(new_entry)); - - return index; -} - -bool QpackHeaderTableBase::SetDynamicTableCapacity(uint64_t capacity) { - if (capacity > maximum_dynamic_table_capacity_) { - return false; - } - - dynamic_table_capacity_ = capacity; - EvictDownToCapacity(capacity); - - QUICHE_DCHECK_LE(dynamic_table_size_, dynamic_table_capacity_); - - return true; -} - -bool QpackHeaderTableBase::SetMaximumDynamicTableCapacity( - uint64_t maximum_dynamic_table_capacity) { - if (maximum_dynamic_table_capacity_ == 0) { - maximum_dynamic_table_capacity_ = maximum_dynamic_table_capacity; - max_entries_ = maximum_dynamic_table_capacity / 32; - return true; - } - // If the value is already set, it should not be changed. - return maximum_dynamic_table_capacity == maximum_dynamic_table_capacity_; -} - -void QpackHeaderTableBase::RemoveEntryFromEnd() { - const uint64_t entry_size = dynamic_entries_.front().Size(); - QUICHE_DCHECK_GE(dynamic_table_size_, entry_size); - dynamic_table_size_ -= entry_size; - - dynamic_entries_.pop_front(); - ++dropped_entry_count_; -} - -void QpackHeaderTableBase::EvictDownToCapacity(uint64_t capacity) { - while (dynamic_table_size_ > capacity) { - QUICHE_DCHECK(!dynamic_entries_.empty()); - RemoveEntryFromEnd(); - } -} - QpackEncoderHeaderTable::QpackEncoderHeaderTable() : static_index_(ObtainQpackStaticTable().GetStaticIndex()), static_name_index_(ObtainQpackStaticTable().GetStaticNameIndex()) {} uint64_t QpackEncoderHeaderTable::InsertEntry(absl::string_view name, absl::string_view value) { - const uint64_t index = QpackHeaderTableBase::InsertEntry(name, value); + const uint64_t index = + QpackHeaderTableBase::InsertEntry(name, value); // Make name and value point to the new entry. name = dynamic_entries().back().name(); @@ -235,7 +162,7 @@ void QpackEncoderHeaderTable::RemoveEntryFromEnd() { dynamic_name_index_.erase(name_it); } - QpackHeaderTableBase::RemoveEntryFromEnd(); + QpackHeaderTableBase::RemoveEntryFromEnd(); } QpackDecoderHeaderTable::QpackDecoderHeaderTable() @@ -249,7 +176,8 @@ QpackDecoderHeaderTable::~QpackDecoderHeaderTable() { uint64_t QpackDecoderHeaderTable::InsertEntry(absl::string_view name, absl::string_view value) { - const uint64_t index = QpackHeaderTableBase::InsertEntry(name, value); + const uint64_t index = + QpackHeaderTableBase::InsertEntry(name, value); // Notify and deregister observers whose threshold is met, if any. while (!observers_.empty()) { diff --git a/gquiche/quic/core/qpack/qpack_header_table.h b/gquiche/quic/core/qpack/qpack_header_table.h index cffee215..792f4690 100644 --- a/gquiche/quic/core/qpack/qpack_header_table.h +++ b/gquiche/quic/core/qpack/qpack_header_table.h @@ -6,12 +6,11 @@ #define QUICHE_QUIC_CORE_QPACK_QPACK_HEADER_TABLE_H_ #include -#include -#include -#include +#include #include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/common/quiche_circular_deque.h" #include "gquiche/spdy/core/hpack/hpack_entry.h" #include "gquiche/spdy/core/hpack/hpack_header_table.h" @@ -21,10 +20,20 @@ using QpackEntry = spdy::HpackEntry; using QpackLookupEntry = spdy::HpackLookupEntry; constexpr size_t kQpackEntrySizeOverhead = spdy::kHpackEntrySizeOverhead; +// Encoder needs pointer stability for |dynamic_index_| and +// |dynamic_name_index_|. However, it does not need random access. +// TODO(b/182349990): Change to a more memory efficient container. +using QpackEncoderDynamicTable = std::deque; + +// Decoder needs random access for LookupEntry(). +// However, it does not need pointer stability. +using QpackDecoderDynamicTable = quiche::QuicheCircularDeque; + // This is a base class for encoder and decoder classes that manage the QPACK // static and dynamic tables. For dynamic entries, it only has a concept of // absolute indices. The caller needs to perform the necessary transformations // to and from relative indices and post-base indices. +template class QUIC_EXPORT_PRIVATE QpackHeaderTableBase { public: QpackHeaderTableBase(); @@ -91,7 +100,6 @@ class QUIC_EXPORT_PRIVATE QpackHeaderTableBase { // |dynamic_table_size_| and |dropped_entry_count_|. virtual void RemoveEntryFromEnd(); - using DynamicEntryTable = spdy::HpackHeaderTable::DynamicEntryTable; const DynamicEntryTable& dynamic_entries() const { return dynamic_entries_; } private: @@ -130,8 +138,92 @@ class QUIC_EXPORT_PRIVATE QpackHeaderTableBase { bool dynamic_table_entry_referenced_; }; +template +QpackHeaderTableBase::QpackHeaderTableBase() + : dynamic_table_size_(0), + dynamic_table_capacity_(0), + maximum_dynamic_table_capacity_(0), + max_entries_(0), + dropped_entry_count_(0), + dynamic_table_entry_referenced_(false) {} + +template +bool QpackHeaderTableBase::EntryFitsDynamicTableCapacity( + absl::string_view name, + absl::string_view value) const { + return QpackEntry::Size(name, value) <= dynamic_table_capacity_; +} + +template +uint64_t QpackHeaderTableBase::InsertEntry( + absl::string_view name, + absl::string_view value) { + QUICHE_DCHECK(EntryFitsDynamicTableCapacity(name, value)); + + const uint64_t index = dropped_entry_count_ + dynamic_entries_.size(); + + // Copy name and value before modifying the container, because evicting + // entries or even inserting a new one might invalidate |name| or |value| if + // they point to an entry. + QpackEntry new_entry((std::string(name)), (std::string(value))); + const size_t entry_size = new_entry.Size(); + + EvictDownToCapacity(dynamic_table_capacity_ - entry_size); + + dynamic_table_size_ += entry_size; + dynamic_entries_.push_back(std::move(new_entry)); + + return index; +} + +template +bool QpackHeaderTableBase::SetDynamicTableCapacity( + uint64_t capacity) { + if (capacity > maximum_dynamic_table_capacity_) { + return false; + } + + dynamic_table_capacity_ = capacity; + EvictDownToCapacity(capacity); + + QUICHE_DCHECK_LE(dynamic_table_size_, dynamic_table_capacity_); + + return true; +} + +template +bool QpackHeaderTableBase::SetMaximumDynamicTableCapacity( + uint64_t maximum_dynamic_table_capacity) { + if (maximum_dynamic_table_capacity_ == 0) { + maximum_dynamic_table_capacity_ = maximum_dynamic_table_capacity; + max_entries_ = maximum_dynamic_table_capacity / 32; + return true; + } + // If the value is already set, it should not be changed. + return maximum_dynamic_table_capacity == maximum_dynamic_table_capacity_; +} + +template +void QpackHeaderTableBase::RemoveEntryFromEnd() { + const uint64_t entry_size = dynamic_entries_.front().Size(); + QUICHE_DCHECK_GE(dynamic_table_size_, entry_size); + dynamic_table_size_ -= entry_size; + + dynamic_entries_.pop_front(); + ++dropped_entry_count_; +} + +template +void QpackHeaderTableBase::EvictDownToCapacity( + uint64_t capacity) { + while (dynamic_table_size_ > capacity) { + QUICHE_DCHECK(!dynamic_entries_.empty()); + RemoveEntryFromEnd(); + } +} + class QUIC_EXPORT_PRIVATE QpackEncoderHeaderTable - : public QpackHeaderTableBase { + : public QpackHeaderTableBase { public: // Result of header table lookup. enum class MatchType { kNameAndValue, kName, kNoMatch }; @@ -198,7 +290,7 @@ class QUIC_EXPORT_PRIVATE QpackEncoderHeaderTable }; class QUIC_EXPORT_PRIVATE QpackDecoderHeaderTable - : public QpackHeaderTableBase { + : public QpackHeaderTableBase { public: // Observer interface for dynamic table insertion. class QUIC_EXPORT_PRIVATE Observer { diff --git a/gquiche/quic/core/qpack/qpack_instruction_decoder.cc b/gquiche/quic/core/qpack/qpack_instruction_decoder.cc index b7da6227..7619a6aa 100644 --- a/gquiche/quic/core/qpack/qpack_instruction_decoder.cc +++ b/gquiche/quic/core/qpack/qpack_instruction_decoder.cc @@ -86,8 +86,6 @@ bool QpackInstructionDecoder::Decode(absl::string_view data) { return true; } } - - return true; } bool QpackInstructionDecoder::AtInstructionBoundary() const { diff --git a/gquiche/quic/core/qpack/qpack_instruction_decoder_test.cc b/gquiche/quic/core/qpack/qpack_instruction_decoder_test.cc index e945f16f..9793a4e0 100644 --- a/gquiche/quic/core/qpack/qpack_instruction_decoder_test.cc +++ b/gquiche/quic/core/qpack/qpack_instruction_decoder_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" using ::testing::_; using ::testing::Eq; diff --git a/gquiche/quic/core/qpack/qpack_instruction_encoder_test.cc b/gquiche/quic/core/qpack/qpack_instruction_encoder_test.cc index fadef178..267a650e 100644 --- a/gquiche/quic/core/qpack/qpack_instruction_encoder_test.cc +++ b/gquiche/quic/core/qpack/qpack_instruction_encoder_test.cc @@ -8,7 +8,6 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { diff --git a/gquiche/quic/core/qpack/qpack_instructions.h b/gquiche/quic/core/qpack/qpack_instructions.h index 54f5ecef..44c432fb 100644 --- a/gquiche/quic/core/qpack/qpack_instructions.h +++ b/gquiche/quic/core/qpack/qpack_instructions.h @@ -194,10 +194,10 @@ class QUIC_EXPORT_PRIVATE QpackInstructionWithValues { QpackInstructionWithValues() = default; // |*instruction| is not owned. - const QpackInstruction* instruction_; - bool s_bit_; - uint64_t varint_; - uint64_t varint2_; + const QpackInstruction* instruction_ = nullptr; + bool s_bit_ = false; + uint64_t varint_ = 0; + uint64_t varint2_ = 0; absl::string_view name_; absl::string_view value_; }; diff --git a/gquiche/quic/core/qpack/qpack_progressive_decoder.cc b/gquiche/quic/core/qpack/qpack_progressive_decoder.cc index 61f0c101..2c4000aa 100644 --- a/gquiche/quic/core/qpack/qpack_progressive_decoder.cc +++ b/gquiche/quic/core/qpack/qpack_progressive_decoder.cc @@ -12,20 +12,27 @@ #include "gquiche/quic/core/qpack/qpack_index_conversions.h" #include "gquiche/quic/core/qpack/qpack_instructions.h" #include "gquiche/quic/core/qpack/qpack_required_insert_count.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" +#include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" namespace quic { +namespace { + +// The value argument passed to OnHeaderDecoded() is from an entry in the static +// table. +constexpr bool kValueFromStaticTable = true; + +} // anonymous namespace + QpackProgressiveDecoder::QpackProgressiveDecoder( - QuicStreamId stream_id, - BlockedStreamLimitEnforcer* enforcer, - DecodingCompletedVisitor* visitor, - QpackDecoderHeaderTable* header_table, + QuicStreamId stream_id, BlockedStreamLimitEnforcer* enforcer, + DecodingCompletedVisitor* visitor, QpackDecoderHeaderTable* header_table, HeadersHandlerInterface* handler) : stream_id_(stream_id), - prefix_decoder_( - std::make_unique(QpackPrefixLanguage(), - this)), + prefix_decoder_(std::make_unique( + QpackPrefixLanguage(), this)), instruction_decoder_(QpackRequestStreamLanguage(), this), enforcer_(enforcer), visitor_(visitor), @@ -89,14 +96,6 @@ void QpackProgressiveDecoder::EndHeaderBlock() { } } -void QpackProgressiveDecoder::OnError(absl::string_view error_message) { - QUICHE_DCHECK(!error_detected_); - - error_detected_ = true; - // Might destroy |this|. - handler_->OnDecodingErrorDetected(error_message); -} - bool QpackProgressiveDecoder::OnInstructionDecoded( const QpackInstruction* instruction) { if (instruction == QpackPrefixInstruction()) { @@ -126,10 +125,9 @@ bool QpackProgressiveDecoder::OnInstructionDecoded( void QpackProgressiveDecoder::OnInstructionDecodingError( QpackInstructionDecoder::ErrorCode /* error_code */, absl::string_view error_message) { - // Ignore |error_code|, because header block decoding errors trigger a - // RESET_STREAM frame which cannot carry an error code more granular than - // QPACK_DECOMPRESSION_FAILED. - OnError(error_message); + // Ignore |error_code| and always use QUIC_QPACK_DECOMPRESSION_FAILED to avoid + // having to define a new QuicErrorCode for every instruction decoder error. + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, error_message); } void QpackProgressiveDecoder::OnInsertCountReachedThreshold() { @@ -164,12 +162,13 @@ bool QpackProgressiveDecoder::DoIndexedHeaderFieldInstruction() { uint64_t absolute_index; if (!QpackRequestStreamRelativeIndexToAbsoluteIndex( instruction_decoder_.varint(), base_, &absolute_index)) { - OnError("Invalid relative index."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid relative index."); return false; } if (absolute_index >= required_insert_count_) { - OnError("Absolute Index must be smaller than Required Insert Count."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); return false; } @@ -180,36 +179,37 @@ bool QpackProgressiveDecoder::DoIndexedHeaderFieldInstruction() { auto entry = header_table_->LookupEntry(/* is_static = */ false, absolute_index); if (!entry) { - OnError("Dynamic table entry already evicted."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); return false; } header_table_->set_dynamic_table_entry_referenced(); - handler_->OnHeaderDecoded(entry->name(), entry->value()); - return true; + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), + entry->value()); } auto entry = header_table_->LookupEntry(/* is_static = */ true, instruction_decoder_.varint()); if (!entry) { - OnError("Static table entry not found."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Static table entry not found."); return false; } - handler_->OnHeaderDecoded(entry->name(), entry->value()); - return true; + return OnHeaderDecoded(kValueFromStaticTable, entry->name(), entry->value()); } bool QpackProgressiveDecoder::DoIndexedHeaderFieldPostBaseInstruction() { uint64_t absolute_index; if (!QpackPostBaseIndexToAbsoluteIndex(instruction_decoder_.varint(), base_, &absolute_index)) { - OnError("Invalid post-base index."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid post-base index."); return false; } if (absolute_index >= required_insert_count_) { - OnError("Absolute Index must be smaller than Required Insert Count."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); return false; } @@ -220,13 +220,13 @@ bool QpackProgressiveDecoder::DoIndexedHeaderFieldPostBaseInstruction() { auto entry = header_table_->LookupEntry(/* is_static = */ false, absolute_index); if (!entry) { - OnError("Dynamic table entry already evicted."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); return false; } header_table_->set_dynamic_table_entry_referenced(); - handler_->OnHeaderDecoded(entry->name(), entry->value()); - return true; + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), entry->value()); } bool QpackProgressiveDecoder::DoLiteralHeaderFieldNameReferenceInstruction() { @@ -234,12 +234,13 @@ bool QpackProgressiveDecoder::DoLiteralHeaderFieldNameReferenceInstruction() { uint64_t absolute_index; if (!QpackRequestStreamRelativeIndexToAbsoluteIndex( instruction_decoder_.varint(), base_, &absolute_index)) { - OnError("Invalid relative index."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid relative index."); return false; } if (absolute_index >= required_insert_count_) { - OnError("Absolute Index must be smaller than Required Insert Count."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); return false; } @@ -250,36 +251,38 @@ bool QpackProgressiveDecoder::DoLiteralHeaderFieldNameReferenceInstruction() { auto entry = header_table_->LookupEntry(/* is_static = */ false, absolute_index); if (!entry) { - OnError("Dynamic table entry already evicted."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); return false; } header_table_->set_dynamic_table_entry_referenced(); - handler_->OnHeaderDecoded(entry->name(), instruction_decoder_.value()); - return true; + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), + instruction_decoder_.value()); } auto entry = header_table_->LookupEntry(/* is_static = */ true, instruction_decoder_.varint()); if (!entry) { - OnError("Static table entry not found."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Static table entry not found."); return false; } - handler_->OnHeaderDecoded(entry->name(), instruction_decoder_.value()); - return true; + return OnHeaderDecoded(kValueFromStaticTable, entry->name(), + instruction_decoder_.value()); } bool QpackProgressiveDecoder::DoLiteralHeaderFieldPostBaseInstruction() { uint64_t absolute_index; if (!QpackPostBaseIndexToAbsoluteIndex(instruction_decoder_.varint(), base_, &absolute_index)) { - OnError("Invalid post-base index."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Invalid post-base index."); return false; } if (absolute_index >= required_insert_count_) { - OnError("Absolute Index must be smaller than Required Insert Count."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Absolute Index must be smaller than Required Insert Count."); return false; } @@ -290,20 +293,19 @@ bool QpackProgressiveDecoder::DoLiteralHeaderFieldPostBaseInstruction() { auto entry = header_table_->LookupEntry(/* is_static = */ false, absolute_index); if (!entry) { - OnError("Dynamic table entry already evicted."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Dynamic table entry already evicted."); return false; } header_table_->set_dynamic_table_entry_referenced(); - handler_->OnHeaderDecoded(entry->name(), instruction_decoder_.value()); - return true; + return OnHeaderDecoded(!kValueFromStaticTable, entry->name(), + instruction_decoder_.value()); } bool QpackProgressiveDecoder::DoLiteralHeaderFieldInstruction() { - handler_->OnHeaderDecoded(instruction_decoder_.name(), - instruction_decoder_.value()); - - return true; + return OnHeaderDecoded(!kValueFromStaticTable, instruction_decoder_.name(), + instruction_decoder_.value()); } bool QpackProgressiveDecoder::DoPrefixInstruction() { @@ -312,14 +314,15 @@ bool QpackProgressiveDecoder::DoPrefixInstruction() { if (!QpackDecodeRequiredInsertCount( prefix_decoder_->varint(), header_table_->max_entries(), header_table_->inserted_entry_count(), &required_insert_count_)) { - OnError("Error decoding Required Insert Count."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Error decoding Required Insert Count."); return false; } const bool sign = prefix_decoder_->s_bit(); const uint64_t delta_base = prefix_decoder_->varint2(); if (!DeltaBaseToBase(sign, delta_base, &base_)) { - OnError("Error calculating Base."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Error calculating Base."); return false; } @@ -327,7 +330,8 @@ bool QpackProgressiveDecoder::DoPrefixInstruction() { if (required_insert_count_ > header_table_->inserted_entry_count()) { if (!enforcer_->OnStreamBlocked(stream_id_)) { - OnError("Limit on number of blocked streams exceeded."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Limit on number of blocked streams exceeded."); return false; } blocked_ = true; @@ -337,6 +341,32 @@ bool QpackProgressiveDecoder::DoPrefixInstruction() { return true; } +bool QpackProgressiveDecoder::OnHeaderDecoded(bool value_from_static_table, + absl::string_view name, + absl::string_view value) { + // Skip test for static table entries as they are all known to be valid. + if (!value_from_static_table) { + // According to Section 10.3 of + // https://quicwg.org/base-drafts/draft-ietf-quic-http.html, + // "[...] HTTP/3 can transport field values that are not valid. While most + // values that can be encoded will not alter field parsing, carriage return + // (CR, ASCII 0x0d), line feed (LF, ASCII 0x0a), and the zero character + // (NUL, ASCII 0x00) might be exploited by an attacker if they are + // translated verbatim. Any request or response that contains a character + // not permitted in a field value MUST be treated as malformed [...]" + for (const auto c : value) { + if (c == '\0' || c == '\n' || c == '\r') { + OnError(QUIC_INVALID_CHARACTER_IN_FIELD_VALUE, + "Invalid character in field value."); + return false; + } + } + } + + handler_->OnHeaderDecoded(name, value); + return true; +} + void QpackProgressiveDecoder::FinishDecoding() { QUICHE_DCHECK(buffer_.empty()); QUICHE_DCHECK(!blocked_); @@ -347,17 +377,18 @@ void QpackProgressiveDecoder::FinishDecoding() { } if (!instruction_decoder_.AtInstructionBoundary()) { - OnError("Incomplete header block."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Incomplete header block."); return; } if (!prefix_decoded_) { - OnError("Incomplete header data prefix."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, "Incomplete header data prefix."); return; } if (required_insert_count_ != required_insert_count_so_far_) { - OnError("Required Insert Count too large."); + OnError(QUIC_QPACK_DECOMPRESSION_FAILED, + "Required Insert Count too large."); return; } @@ -365,6 +396,15 @@ void QpackProgressiveDecoder::FinishDecoding() { handler_->OnDecodingCompleted(); } +void QpackProgressiveDecoder::OnError(QuicErrorCode error_code, + absl::string_view error_message) { + QUICHE_DCHECK(!error_detected_); + + error_detected_ = true; + // Might destroy |this|. + handler_->OnDecodingErrorDetected(error_code, error_message); +} + bool QpackProgressiveDecoder::DeltaBaseToBase(bool sign, uint64_t delta_base, uint64_t* base) { diff --git a/gquiche/quic/core/qpack/qpack_progressive_decoder.h b/gquiche/quic/core/qpack/qpack_progressive_decoder.h index 9a31093c..8dc71e36 100644 --- a/gquiche/quic/core/qpack/qpack_progressive_decoder.h +++ b/gquiche/quic/core/qpack/qpack_progressive_decoder.h @@ -13,6 +13,7 @@ #include "gquiche/quic/core/qpack/qpack_encoder_stream_receiver.h" #include "gquiche/quic/core/qpack/qpack_header_table.h" #include "gquiche/quic/core/qpack/qpack_instruction_decoder.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_export.h" @@ -46,7 +47,8 @@ class QUIC_EXPORT_PRIVATE QpackProgressiveDecoder // Called when a decoding error has occurred. No other methods will be // called afterwards. Implementations are allowed to destroy // the QpackProgressiveDecoder instance synchronously. - virtual void OnDecodingErrorDetected(absl::string_view error_message) = 0; + virtual void OnDecodingErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) = 0; }; // Interface for keeping track of blocked streams for the purpose of enforcing @@ -95,9 +97,6 @@ class QUIC_EXPORT_PRIVATE QpackProgressiveDecoder // through Decode(). No methods must be called afterwards. void EndHeaderBlock(); - // Called on error. - void OnError(absl::string_view error_message); - // QpackInstructionDecoder::Delegate implementation. bool OnInstructionDecoded(const QpackInstruction* instruction) override; void OnInstructionDecodingError(QpackInstructionDecoder::ErrorCode error_code, @@ -115,9 +114,20 @@ class QUIC_EXPORT_PRIVATE QpackProgressiveDecoder bool DoLiteralHeaderFieldInstruction(); bool DoPrefixInstruction(); + // Called when an entry is decoded. Performs validation and calls + // HeadersHandlerInterface::OnHeaderDecoded() or OnError() as needed. Returns + // true if header value is valid, false otherwise. Skips validation if + // |value_from_static_table| is true, because static table entries are always + // valid. + bool OnHeaderDecoded(bool value_from_static_table, absl::string_view name, + absl::string_view value); + // Called as soon as EndHeaderBlock() is called and decoding is not blocked. void FinishDecoding(); + // Called on error. + void OnError(QuicErrorCode error_code, absl::string_view error_message); + // Calculates Base from |required_insert_count_|, which must be set before // calling this method, and sign bit and Delta Base in the Header Data Prefix, // which are passed in as arguments. Returns true on success, false on diff --git a/gquiche/quic/core/qpack/qpack_receive_stream.cc b/gquiche/quic/core/qpack/qpack_receive_stream.cc index d647109b..a13851a9 100644 --- a/gquiche/quic/core/qpack/qpack_receive_stream.cc +++ b/gquiche/quic/core/qpack/qpack_receive_stream.cc @@ -11,8 +11,7 @@ namespace quic { QpackReceiveStream::QpackReceiveStream(PendingStream* pending, QuicSession* session, QpackStreamReceiver* receiver) - : QuicStream(pending, session, READ_UNIDIRECTIONAL, /*is_static=*/true), - receiver_(receiver) {} + : QuicStream(pending, session, /*is_static=*/true), receiver_(receiver) {} void QpackReceiveStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { stream_delegate()->OnStreamError( diff --git a/gquiche/quic/core/qpack/qpack_send_stream.cc b/gquiche/quic/core/qpack/qpack_send_stream.cc index 26ca6f8a..beec2ce1 100644 --- a/gquiche/quic/core/qpack/qpack_send_stream.cc +++ b/gquiche/quic/core/qpack/qpack_send_stream.cc @@ -21,7 +21,7 @@ void QpackSendStream::OnStreamReset(const QuicRstStreamFrame& /*frame*/) { << "OnStreamReset() called for write unidirectional stream."; } -bool QpackSendStream::OnStopSending(QuicRstStreamErrorCode /* code */) { +bool QpackSendStream::OnStopSending(QuicResetStreamError /* code */) { stream_delegate()->OnStreamError( QUIC_HTTP_CLOSED_CRITICAL_STREAM, "STOP_SENDING received for QPACK send stream"); diff --git a/gquiche/quic/core/qpack/qpack_send_stream.h b/gquiche/quic/core/qpack/qpack_send_stream.h index b66d8d69..32cda6a0 100644 --- a/gquiche/quic/core/qpack/qpack_send_stream.h +++ b/gquiche/quic/core/qpack/qpack_send_stream.h @@ -33,7 +33,7 @@ class QUIC_EXPORT_PRIVATE QpackSendStream : public QuicStream, // Overriding QuicStream::OnStopSending() to make sure QPACK stream is never // closed before connection. void OnStreamReset(const QuicRstStreamFrame& frame) override; - bool OnStopSending(QuicRstStreamErrorCode code) override; + bool OnStopSending(QuicResetStreamError code) override; // The send QPACK stream is write unidirectional, so this method // should never be called. diff --git a/gquiche/quic/core/qpack/qpack_send_stream_test.cc b/gquiche/quic/core/qpack/qpack_send_stream_test.cc index 224aec26..626d5879 100644 --- a/gquiche/quic/core/qpack/qpack_send_stream_test.cc +++ b/gquiche/quic/core/qpack/qpack_send_stream_test.cc @@ -119,7 +119,8 @@ TEST_P(QpackSendStreamTest, WriteStreamTypeOnlyFirstTime) { TEST_P(QpackSendStreamTest, StopSendingQpackStream) { EXPECT_CALL(*connection_, CloseConnection(QUIC_HTTP_CLOSED_CRITICAL_STREAM, _, _)); - qpack_send_stream_->OnStopSending(QUIC_STREAM_CANCELLED); + qpack_send_stream_->OnStopSending( + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED)); } TEST_P(QpackSendStreamTest, ReceiveDataOnSendStream) { diff --git a/gquiche/quic/core/quic_alarm.cc b/gquiche/quic/core/quic_alarm.cc index 11789e3e..09a4a293 100644 --- a/gquiche/quic/core/quic_alarm.cc +++ b/gquiche/quic/core/quic_alarm.cc @@ -4,30 +4,60 @@ #include "gquiche/quic/core/quic_alarm.h" +#include + +#include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" +#include "gquiche/quic/platform/api/quic_flags.h" +#include "gquiche/quic/platform/api/quic_stack_trace.h" + namespace quic { QuicAlarm::QuicAlarm(QuicArenaScopedPtr delegate) : delegate_(std::move(delegate)), deadline_(QuicTime::Zero()) {} -QuicAlarm::~QuicAlarm() {} +QuicAlarm::~QuicAlarm() { + if (IsSet()) { + QUIC_CODE_COUNT(quic_alarm_not_cancelled_in_dtor); + } +} void QuicAlarm::Set(QuicTime new_deadline) { QUICHE_DCHECK(!IsSet()); QUICHE_DCHECK(new_deadline.IsInitialized()); + + if (IsPermanentlyCancelled()) { + QUIC_BUG(quic_alarm_illegal_set) + << "Set called after alarm is permanently cancelled. new_deadline:" + << new_deadline; + return; + } + deadline_ = new_deadline; SetImpl(); } -void QuicAlarm::Cancel() { - if (!IsSet()) { - // Don't try to cancel an alarm that hasn't been set. - return; +void QuicAlarm::CancelInternal(bool permanent) { + if (IsSet()) { + deadline_ = QuicTime::Zero(); + CancelImpl(); + } + + if (permanent) { + delegate_.reset(); } - deadline_ = QuicTime::Zero(); - CancelImpl(); } +bool QuicAlarm::IsPermanentlyCancelled() const { return delegate_ == nullptr; } + void QuicAlarm::Update(QuicTime new_deadline, QuicTime::Delta granularity) { + if (IsPermanentlyCancelled()) { + QUIC_BUG(quic_alarm_illegal_update) + << "Update called after alarm is permanently cancelled. new_deadline:" + << new_deadline << ", granularity:" << granularity; + return; + } + if (!new_deadline.IsInitialized()) { Cancel(); return; @@ -55,7 +85,11 @@ void QuicAlarm::Fire() { } deadline_ = QuicTime::Zero(); - delegate_->OnAlarm(); + if (!IsPermanentlyCancelled()) { + QuicConnectionContextSwitcher context_switcher( + delegate_->GetConnectionContext()); + delegate_->OnAlarm(); + } } void QuicAlarm::UpdateImpl() { diff --git a/gquiche/quic/core/quic_alarm.h b/gquiche/quic/core/quic_alarm.h index 7e15a849..36323a57 100644 --- a/gquiche/quic/core/quic_alarm.h +++ b/gquiche/quic/core/quic_alarm.h @@ -6,6 +6,7 @@ #define QUICHE_QUIC_CORE_QUIC_ALARM_H_ #include "gquiche/quic/core/quic_arena_scoped_ptr.h" +#include "gquiche/quic/core/quic_connection_context.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/platform/api/quic_export.h" @@ -22,10 +23,39 @@ class QUIC_EXPORT_PRIVATE QuicAlarm { public: virtual ~Delegate() {} + // If the alarm belongs to a single QuicConnection, return the corresponding + // QuicConnection.context_. Note the context_ is the first member of + // QuicConnection, so it should outlive the delegate. + // Otherwise return nullptr. + // The OnAlarm function will be called under the connection context, if any. + virtual QuicConnectionContext* GetConnectionContext() = 0; + // Invoked when the alarm fires. virtual void OnAlarm() = 0; }; + // DelegateWithContext is a Delegate with a QuicConnectionContext* stored as a + // member variable. + class QUIC_EXPORT_PRIVATE DelegateWithContext : public Delegate { + public: + explicit DelegateWithContext(QuicConnectionContext* context) + : context_(context) {} + ~DelegateWithContext() override {} + QuicConnectionContext* GetConnectionContext() override { return context_; } + + private: + QuicConnectionContext* context_; + }; + + // DelegateWithoutContext is a Delegate that does not have a corresponding + // context. Typically this means one object of the child class deals with many + // connections. + class QUIC_EXPORT_PRIVATE DelegateWithoutContext : public Delegate { + public: + ~DelegateWithoutContext() override {} + QuicConnectionContext* GetConnectionContext() override { return nullptr; } + }; + explicit QuicAlarm(QuicArenaScopedPtr delegate); QuicAlarm(const QuicAlarm&) = delete; QuicAlarm& operator=(const QuicAlarm&) = delete; @@ -36,11 +66,18 @@ class QUIC_EXPORT_PRIVATE QuicAlarm { // then Set(). void Set(QuicTime new_deadline); - // Cancels the alarm. May be called repeatedly. Does not - // guarantee that the underlying scheduling system will remove - // the alarm's associated task, but guarantees that the - // delegates OnAlarm method will not be called. - void Cancel(); + // Both PermanentCancel() and Cancel() can cancel the alarm. If permanent, + // future calls to Set() and Update() will become no-op except emitting an + // error log. + // + // Both may be called repeatedly. Does not guarantee that the underlying + // scheduling system will remove the alarm's associated task, but guarantees + // that the delegates OnAlarm method will not be called. + void PermanentCancel() { CancelInternal(true); } + void Cancel() { CancelInternal(false); } + + // Return true if PermanentCancel() has been called. + bool IsPermanentlyCancelled() const; // Cancels and sets the alarm if the |deadline| is farther from the current // deadline than |granularity|, and otherwise does nothing. If |deadline| is @@ -77,6 +114,8 @@ class QUIC_EXPORT_PRIVATE QuicAlarm { void Fire(); private: + void CancelInternal(bool permanent); + QuicArenaScopedPtr delegate_; QuicTime deadline_; }; diff --git a/gquiche/quic/core/quic_alarm_test.cc b/gquiche/quic/core/quic_alarm_test.cc index 16b77cca..2f6337c8 100644 --- a/gquiche/quic/core/quic_alarm_test.cc +++ b/gquiche/quic/core/quic_alarm_test.cc @@ -4,20 +4,41 @@ #include "gquiche/quic/core/quic_alarm.h" +#include "gquiche/quic/core/quic_connection_context.h" +#include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_test.h" +using testing::ElementsAre; using testing::Invoke; +using testing::Return; namespace quic { namespace test { namespace { +class TraceCollector : public QuicConnectionTracer { + public: + ~TraceCollector() override = default; + + void PrintLiteral(const char* literal) override { trace_.push_back(literal); } + + void PrintString(absl::string_view s) override { + trace_.push_back(std::string(s)); + } + + const std::vector& trace() const { return trace_; } + + private: + std::vector trace_; +}; + class MockDelegate : public QuicAlarm::Delegate { public: + MOCK_METHOD(QuicConnectionContext*, GetConnectionContext, (), (override)); MOCK_METHOD(void, OnAlarm, (), (override)); }; -class DestructiveDelegate : public QuicAlarm::Delegate { +class DestructiveDelegate : public QuicAlarm::DelegateWithoutContext { public: DestructiveDelegate() : alarm_(nullptr) {} @@ -111,6 +132,29 @@ TEST_F(QuicAlarmTest, Cancel) { EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); } +TEST_F(QuicAlarmTest, PermanentCancel) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + alarm_.PermanentCancel(); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); + + EXPECT_QUIC_BUG(alarm_.Set(deadline), + "Set called after alarm is permanently cancelled"); + EXPECT_TRUE(alarm_.IsPermanentlyCancelled()); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); + + EXPECT_QUIC_BUG(alarm_.Update(deadline, QuicTime::Delta::Zero()), + "Update called after alarm is permanently cancelled"); + EXPECT_TRUE(alarm_.IsPermanentlyCancelled()); + EXPECT_FALSE(alarm_.IsSet()); + EXPECT_FALSE(alarm_.scheduled()); + EXPECT_EQ(QuicTime::Zero(), alarm_.deadline()); +} + TEST_F(QuicAlarmTest, Update) { QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); alarm_.Set(deadline); @@ -161,6 +205,57 @@ TEST_F(QuicAlarmTest, FireDestroysAlarm) { alarm->FireAlarm(); } +TEST_F(QuicAlarmTest, NullAlarmContext) { + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + + EXPECT_CALL(*delegate_, GetConnectionContext()).WillOnce(Return(nullptr)); + + EXPECT_CALL(*delegate_, OnAlarm()).WillOnce(Invoke([] { + QUIC_TRACELITERAL("Alarm fired."); + })); + alarm_.FireAlarm(); +} + +TEST_F(QuicAlarmTest, AlarmContextWithNullTracer) { + QuicConnectionContext context; + ASSERT_EQ(context.tracer, nullptr); + + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + + EXPECT_CALL(*delegate_, GetConnectionContext()).WillOnce(Return(&context)); + + EXPECT_CALL(*delegate_, OnAlarm()).WillOnce(Invoke([] { + QUIC_TRACELITERAL("Alarm fired."); + })); + alarm_.FireAlarm(); +} + +TEST_F(QuicAlarmTest, AlarmContextWithTracer) { + QuicConnectionContext context; + std::unique_ptr tracer = std::make_unique(); + const TraceCollector& tracer_ref = *tracer; + context.tracer = std::move(tracer); + + QuicTime deadline = QuicTime::Zero() + QuicTime::Delta::FromSeconds(7); + alarm_.Set(deadline); + + EXPECT_CALL(*delegate_, GetConnectionContext()).WillOnce(Return(&context)); + + EXPECT_CALL(*delegate_, OnAlarm()).WillOnce(Invoke([] { + QUIC_TRACELITERAL("Alarm fired."); + })); + + // Since |context| is not installed in the current thread, the messages before + // and after FireAlarm() should not be collected by |tracer|. + QUIC_TRACELITERAL("Should not be collected before alarm."); + alarm_.FireAlarm(); + QUIC_TRACELITERAL("Should not be collected after alarm."); + + EXPECT_THAT(tracer_ref.trace(), ElementsAre("Alarm fired.")); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/quic_buffer_allocator.h b/gquiche/quic/core/quic_buffer_allocator.h index 1582839e..ffb73461 100644 --- a/gquiche/quic/core/quic_buffer_allocator.h +++ b/gquiche/quic/core/quic_buffer_allocator.h @@ -9,6 +9,7 @@ #include +#include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_export.h" namespace quic { @@ -56,6 +57,58 @@ inline QuicUniqueBufferPtr MakeUniqueBuffer(QuicBufferAllocator* allocator, QuicBufferDeleter(allocator)); } +// QuicUniqueBufferPtr with a length attached to it. Similar to QuicMemSlice, +// except unlike QuicMemSlice, QuicBuffer is mutable and is not +// platform-specific. Also unlike QuicMemSlice, QuicBuffer can be empty. +class QUIC_EXPORT_PRIVATE QuicBuffer { + public: + QuicBuffer() : buffer_(nullptr, QuicBufferDeleter(nullptr)), size_(0) {} + QuicBuffer(QuicBufferAllocator* allocator, size_t size) + : buffer_(MakeUniqueBuffer(allocator, size)), size_(size) {} + + // Make sure the move constructor zeroes out the size field. + QuicBuffer(QuicBuffer&& other) + : buffer_(std::move(other.buffer_)), size_(other.size_) { + other.buffer_ = nullptr; + other.size_ = 0; + } + QuicBuffer& operator=(QuicBuffer&& other) { + buffer_ = std::move(other.buffer_); + size_ = other.size_; + + other.buffer_ = nullptr; + other.size_ = 0; + return *this; + } + + // Convenience method to initialize a QuicBuffer by copying from an existing + // one. + static QuicBuffer Copy(QuicBufferAllocator* allocator, + absl::string_view data) { + QuicBuffer result(allocator, data.size()); + memcpy(result.data(), data.data(), data.size()); + return result; + } + + const char* data() const { return buffer_.get(); } + char* data() { return buffer_.get(); } + size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + absl::string_view AsStringView() const { + return absl::string_view(data(), size()); + } + + // Releases the ownership of the underlying buffer. + QuicUniqueBufferPtr Release() { + size_ = 0; + return std::move(buffer_); + } + + private: + QuicUniqueBufferPtr buffer_; + size_t size_; +}; + } // namespace quic #endif // QUICHE_QUIC_CORE_QUIC_BUFFER_ALLOCATOR_H_ diff --git a/gquiche/quic/core/quic_buffered_packet_store.cc b/gquiche/quic/core/quic_buffered_packet_store.cc index c2572630..e8824660 100644 --- a/gquiche/quic/core/quic_buffered_packet_store.cc +++ b/gquiche/quic/core/quic_buffered_packet_store.cc @@ -8,7 +8,6 @@ #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { @@ -25,7 +24,7 @@ static const size_t kMaxConnectionsWithoutCHLO = namespace { // This alarm removes expired entries in map each time this alarm fires. -class ConnectionExpireAlarm : public QuicAlarm::Delegate { +class ConnectionExpireAlarm : public QuicAlarm::DelegateWithoutContext { public: explicit ConnectionExpireAlarm(QuicBufferedPacketStore* store) : connection_store_(store) {} @@ -77,30 +76,27 @@ QuicBufferedPacketStore::QuicBufferedPacketStore( expiration_alarm_( alarm_factory->CreateAlarm(new ConnectionExpireAlarm(this))) {} -QuicBufferedPacketStore::~QuicBufferedPacketStore() {} +QuicBufferedPacketStore::~QuicBufferedPacketStore() { + if (expiration_alarm_ != nullptr) { + expiration_alarm_->PermanentCancel(); + } +} EnqueuePacketResult QuicBufferedPacketStore::EnqueuePacket( - QuicConnectionId connection_id, - bool ietf_quic, - const QuicReceivedPacket& packet, - QuicSocketAddress self_address, - QuicSocketAddress peer_address, - bool is_chlo, - const std::vector& alpns, - const absl::string_view sni, - const ParsedQuicVersion& version) { + QuicConnectionId connection_id, bool ietf_quic, + const QuicReceivedPacket& packet, QuicSocketAddress self_address, + QuicSocketAddress peer_address, const ParsedQuicVersion& version, + absl::optional parsed_chlo) { + const bool is_chlo = parsed_chlo.has_value(); QUIC_BUG_IF(quic_bug_12410_1, !GetQuicFlag(FLAGS_quic_allow_chlo_buffering)) << "Shouldn't buffer packets if disabled via flag."; QUIC_BUG_IF(quic_bug_12410_2, - is_chlo && QuicContainsKey(connections_with_chlo_, connection_id)) + is_chlo && connections_with_chlo_.contains(connection_id)) << "Shouldn't buffer duplicated CHLO on connection " << connection_id; - QUIC_BUG_IF(quic_bug_12410_3, !is_chlo && !alpns.empty()) - << "Shouldn't have an ALPN defined for a non-CHLO packet."; QUIC_BUG_IF(quic_bug_12410_4, is_chlo && !version.IsKnown()) << "Should have version for CHLO packet."; - const bool is_first_packet = - !QuicContainsKey(undecryptable_packets_, connection_id); + const bool is_first_packet = !undecryptable_packets_.contains(connection_id); if (is_first_packet) { if (ShouldNotBufferPacket(is_chlo)) { // Drop the packet if the upper limit of undecryptable packets has been @@ -112,17 +108,16 @@ EnqueuePacketResult QuicBufferedPacketStore::EnqueuePacket( undecryptable_packets_.back().second.ietf_quic = ietf_quic; undecryptable_packets_.back().second.version = version; } - QUICHE_CHECK(QuicContainsKey(undecryptable_packets_, connection_id)); + QUICHE_CHECK(undecryptable_packets_.contains(connection_id)); BufferedPacketList& queue = undecryptable_packets_.find(connection_id)->second; if (!is_chlo) { // If current packet is not CHLO, it might not be buffered because store // only buffers certain number of undecryptable packets per connection. - size_t num_non_chlo_packets = - QuicContainsKey(connections_with_chlo_, connection_id) - ? (queue.buffered_packets.size() - 1) - : queue.buffered_packets.size(); + size_t num_non_chlo_packets = connections_with_chlo_.contains(connection_id) + ? (queue.buffered_packets.size() - 1) + : queue.buffered_packets.size(); if (num_non_chlo_packets >= kDefaultMaxUndecryptablePackets) { // If there are kMaxBufferedPacketsPerConnection packets buffered up for // this connection, drop the current packet. @@ -142,8 +137,7 @@ EnqueuePacketResult QuicBufferedPacketStore::EnqueuePacket( // Add CHLO to the beginning of buffered packets so that it can be delivered // first later. queue.buffered_packets.push_front(std::move(new_entry)); - queue.alpns = alpns; - queue.sni = std::string(sni); + queue.parsed_chlo = std::move(parsed_chlo); connections_with_chlo_[connection_id] = false; // Dummy value. // Set the version of buffered packets of this connection on CHLO. queue.version = version; @@ -169,7 +163,7 @@ EnqueuePacketResult QuicBufferedPacketStore::EnqueuePacket( bool QuicBufferedPacketStore::HasBufferedPackets( QuicConnectionId connection_id) const { - return QuicContainsKey(undecryptable_packets_, connection_id); + return undecryptable_packets_.contains(connection_id); } bool QuicBufferedPacketStore::HasChlosBuffered() const { @@ -247,22 +241,24 @@ BufferedPacketList QuicBufferedPacketStore::DeliverPacketsForNextConnection( connections_with_chlo_.pop_front(); BufferedPacketList packets = DeliverPackets(*connection_id); - QUICHE_DCHECK(!packets.buffered_packets.empty()) - << "Try to deliver connectons without CHLO"; + QUICHE_DCHECK(!packets.buffered_packets.empty() && + packets.parsed_chlo.has_value()) + << "Try to deliver connectons without CHLO. # packets:" + << packets.buffered_packets.size() + << ", has_parsed_chlo:" << packets.parsed_chlo.has_value(); return packets; } bool QuicBufferedPacketStore::HasChloForConnection( QuicConnectionId connection_id) { - return QuicContainsKey(connections_with_chlo_, connection_id); + return connections_with_chlo_.contains(connection_id); } bool QuicBufferedPacketStore::IngestPacketForTlsChloExtraction( - const QuicConnectionId& connection_id, - const ParsedQuicVersion& version, - const QuicReceivedPacket& packet, - std::vector* out_alpns, - std::string* out_sni) { + const QuicConnectionId& connection_id, const ParsedQuicVersion& version, + const QuicReceivedPacket& packet, std::vector* out_alpns, + std::string* out_sni, bool* out_resumption_attempted, + bool* out_early_data_attempted) { QUICHE_DCHECK_NE(out_alpns, nullptr); QUICHE_DCHECK_NE(out_sni, nullptr); QUICHE_DCHECK_EQ(version.handshake_protocol, PROTOCOL_TLS1_3); @@ -276,8 +272,11 @@ bool QuicBufferedPacketStore::IngestPacketForTlsChloExtraction( if (!it->second.tls_chlo_extractor.HasParsedFullChlo()) { return false; } - *out_alpns = it->second.tls_chlo_extractor.alpns(); - *out_sni = it->second.tls_chlo_extractor.server_name(); + const TlsChloExtractor& tls_chlo_extractor = it->second.tls_chlo_extractor; + *out_alpns = tls_chlo_extractor.alpns(); + *out_sni = tls_chlo_extractor.server_name(); + *out_resumption_attempted = tls_chlo_extractor.resumption_attempted(); + *out_early_data_attempted = tls_chlo_extractor.early_data_attempted(); return true; } diff --git a/gquiche/quic/core/quic_buffered_packet_store.h b/gquiche/quic/core/quic_buffered_packet_store.h index db37f46f..fefb3b95 100644 --- a/gquiche/quic/core/quic_buffered_packet_store.h +++ b/gquiche/quic/core/quic_buffered_packet_store.h @@ -14,9 +14,9 @@ #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/core/tls_chlo_extractor.h" -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -67,9 +67,8 @@ class QUIC_NO_EXPORT QuicBufferedPacketStore { std::list buffered_packets; QuicTime creation_time; - // The ALPNs from the CHLO, if found. - std::vector alpns; - std::string sni; + // |parsed_chlo| is set iff the entire CHLO has been received. + absl::optional parsed_chlo; // Indicating whether this is an IETF QUIC connection. bool ietf_quic; // If buffered_packets contains the CHLO, it is the version of the CHLO. @@ -78,9 +77,9 @@ class QUIC_NO_EXPORT QuicBufferedPacketStore { TlsChloExtractor tls_chlo_extractor; }; - using BufferedPacketMap = QuicLinkedHashMap; + using BufferedPacketMap = quiche::QuicheLinkedHashMap; class QUIC_NO_EXPORT VisitorInterface { public: @@ -101,18 +100,14 @@ class QUIC_NO_EXPORT QuicBufferedPacketStore { QuicBufferedPacketStore& operator=(const QuicBufferedPacketStore&) = delete; - // Adds a copy of packet into packet queue for given connection. - // TODO(danzh): Consider to split this method to EnqueueChlo() and - // EnqueueDataPacket(). - EnqueuePacketResult EnqueuePacket(QuicConnectionId connection_id, - bool ietf_quic, - const QuicReceivedPacket& packet, - QuicSocketAddress self_address, - QuicSocketAddress peer_address, - bool is_chlo, - const std::vector& alpns, - const absl::string_view sni, - const ParsedQuicVersion& version); + // Adds a copy of packet into the packet queue for given connection. If the + // packet is the last one of the CHLO, |parsed_chlo| will contain a parsed + // version of the CHLO. + EnqueuePacketResult EnqueuePacket( + QuicConnectionId connection_id, bool ietf_quic, + const QuicReceivedPacket& packet, QuicSocketAddress self_address, + QuicSocketAddress peer_address, const ParsedQuicVersion& version, + absl::optional parsed_chlo); // Returns true if there are any packets buffered for |connection_id|. bool HasBufferedPackets(QuicConnectionId connection_id) const; @@ -122,11 +117,16 @@ class QUIC_NO_EXPORT QuicBufferedPacketStore { // Returns whether we've now parsed a full multi-packet TLS CHLO. // When this returns true, |out_alpns| is populated with the list of ALPNs // extracted from the CHLO. |out_sni| is populated with the SNI tag in CHLO. + // |out_resumption_attempted| is populated if the CHLO has the + // 'pre_shared_key' TLS extension. |out_early_data_attempted| is populated if + // the CHLO has the 'early_data' TLS extension. bool IngestPacketForTlsChloExtraction(const QuicConnectionId& connection_id, const ParsedQuicVersion& version, const QuicReceivedPacket& packet, std::vector* out_alpns, - std::string* out_sni); + std::string* out_sni, + bool* out_resumption_attempted, + bool* out_early_data_attempted); // Returns the list of buffered packets for |connection_id| and removes them // from the store. Returns an empty list if no early arrived packets for this @@ -186,7 +186,7 @@ class QUIC_NO_EXPORT QuicBufferedPacketStore { // Keeps track of connection with CHLO buffered up already and the order they // arrive. - QuicLinkedHashMap + quiche::QuicheLinkedHashMap connections_with_chlo_; }; diff --git a/gquiche/quic/core/quic_buffered_packet_store_test.cc b/gquiche/quic/core/quic_buffered_packet_store_test.cc index ce29c500..6f897304 100644 --- a/gquiche/quic/core/quic_buffered_packet_store_test.cc +++ b/gquiche/quic/core/quic_buffered_packet_store_test.cc @@ -22,6 +22,10 @@ static const size_t kMaxConnectionsWithoutCHLO = namespace test { namespace { +const absl::optional kNoParsedChlo; +const absl::optional kDefaultParsedChlo = + absl::make_optional(); + using BufferedPacket = QuicBufferedPacketStore::BufferedPacket; using BufferedPacketList = QuicBufferedPacketStore::BufferedPacketList; using EnqueuePacketResult = QuicBufferedPacketStore::EnqueuePacketResult; @@ -70,13 +74,12 @@ class QuicBufferedPacketStoreTest : public QuicTest { TEST_F(QuicBufferedPacketStoreTest, SimpleEnqueueAndDeliverPacket) { QuicConnectionId connection_id = TestConnectionId(1); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); auto packets = store_.DeliverPackets(connection_id); const std::list& queue = packets.buffered_packets; ASSERT_EQ(1u, queue.size()); - // The alpn should be ignored for non-chlo packets. - ASSERT_TRUE(packets.alpns.empty()); + ASSERT_FALSE(packets.parsed_chlo.has_value()); // There is no valid version because CHLO has not arrived. EXPECT_EQ(invalid_version_, packets.version); // Check content of the only packet in the queue. @@ -93,9 +96,9 @@ TEST_F(QuicBufferedPacketStoreTest, DifferentPacketAddressOnOneConnection) { QuicSocketAddress addr_with_new_port(QuicIpAddress::Any4(), 256); QuicConnectionId connection_id = TestConnectionId(1); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - addr_with_new_port, false, {}, "", invalid_version_); + addr_with_new_port, invalid_version_, kNoParsedChlo); std::list queue = store_.DeliverPackets(connection_id).buffered_packets; ASSERT_EQ(2u, queue.size()); @@ -110,9 +113,9 @@ TEST_F(QuicBufferedPacketStoreTest, for (uint64_t conn_id = 1; conn_id <= num_connections; ++conn_id) { QuicConnectionId connection_id = TestConnectionId(conn_id); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); } // Deliver packets in reversed order. @@ -132,14 +135,15 @@ TEST_F(QuicBufferedPacketStoreTest, QuicConnectionId connection_id = TestConnectionId(1); // Arrived CHLO packet shouldn't affect how many non-CHLO pacekts store can // keep. - EXPECT_EQ(QuicBufferedPacketStore::SUCCESS, - store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, true, {}, "", valid_version_)); + EXPECT_EQ( + QuicBufferedPacketStore::SUCCESS, + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo)); for (size_t i = 1; i <= num_packets; ++i) { // Only first |kDefaultMaxUndecryptablePackets packets| will be buffered. EnqueuePacketResult result = store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); if (i <= kDefaultMaxUndecryptablePackets) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); } else { @@ -161,7 +165,7 @@ TEST_F(QuicBufferedPacketStoreTest, ReachNonChloConnectionUpperLimit) { QuicConnectionId connection_id = TestConnectionId(conn_id); EnqueuePacketResult result = store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); if (conn_id <= kMaxConnectionsWithoutCHLO) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); } else { @@ -190,8 +194,8 @@ TEST_F(QuicBufferedPacketStoreTest, for (uint64_t conn_id = 1; conn_id <= num_chlos; ++conn_id) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, store_.EnqueuePacket(TestConnectionId(conn_id), false, packet_, - self_address_, peer_address_, true, {}, "", - valid_version_)); + self_address_, peer_address_, valid_version_, + kDefaultParsedChlo)); } // Send data packets on another |kMaxConnectionsWithoutCHLO| connections. @@ -201,7 +205,7 @@ TEST_F(QuicBufferedPacketStoreTest, QuicConnectionId connection_id = TestConnectionId(conn_id); EnqueuePacketResult result = store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, true, {}, "", valid_version_); + peer_address_, valid_version_, kDefaultParsedChlo); if (conn_id <= kDefaultMaxConnectionsInStore) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, result); } else { @@ -217,7 +221,7 @@ TEST_F(QuicBufferedPacketStoreTest, EnqueueChloOnTooManyDifferentConnections) { EXPECT_EQ( EnqueuePacketResult::SUCCESS, store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_)); + peer_address_, invalid_version_, kNoParsedChlo)); } // Buffer CHLOs on other connections till store is full. @@ -226,7 +230,7 @@ TEST_F(QuicBufferedPacketStoreTest, EnqueueChloOnTooManyDifferentConnections) { QuicConnectionId connection_id = TestConnectionId(i); EnqueuePacketResult rs = store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, true, {}, "", valid_version_); + peer_address_, valid_version_, kDefaultParsedChlo); if (i <= kDefaultMaxConnectionsInStore) { EXPECT_EQ(EnqueuePacketResult::SUCCESS, rs); EXPECT_TRUE(store_.HasChloForConnection(connection_id)); @@ -240,10 +244,11 @@ TEST_F(QuicBufferedPacketStoreTest, EnqueueChloOnTooManyDifferentConnections) { // But buffering a CHLO belonging to a connection already has data packet // buffered in the store should success. This is the connection should be // delivered at last. - EXPECT_EQ(EnqueuePacketResult::SUCCESS, - store_.EnqueuePacket( - /*connection_id=*/TestConnectionId(1), false, packet_, - self_address_, peer_address_, true, {}, "", valid_version_)); + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket( + /*connection_id=*/TestConnectionId(1), false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo)); EXPECT_TRUE(store_.HasChloForConnection( /*connection_id=*/TestConnectionId(1))); @@ -271,15 +276,16 @@ TEST_F(QuicBufferedPacketStoreTest, EnqueueChloOnTooManyDifferentConnections) { TEST_F(QuicBufferedPacketStoreTest, PacketQueueExpiredBeforeDelivery) { QuicConnectionId connection_id = TestConnectionId(1); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); - EXPECT_EQ(EnqueuePacketResult::SUCCESS, - store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, true, {}, "", valid_version_)); + peer_address_, invalid_version_, kNoParsedChlo); + EXPECT_EQ( + EnqueuePacketResult::SUCCESS, + store_.EnqueuePacket(connection_id, false, packet_, self_address_, + peer_address_, valid_version_, kDefaultParsedChlo)); QuicConnectionId connection_id2 = TestConnectionId(2); EXPECT_EQ( EnqueuePacketResult::SUCCESS, store_.EnqueuePacket(connection_id2, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_)); + peer_address_, invalid_version_, kNoParsedChlo)); // CHLO on connection 3 arrives 1ms later. clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); @@ -288,7 +294,8 @@ TEST_F(QuicBufferedPacketStoreTest, PacketQueueExpiredBeforeDelivery) { // connections. QuicSocketAddress another_client_address(QuicIpAddress::Any4(), 255); store_.EnqueuePacket(connection_id3, false, packet_, self_address_, - another_client_address, true, {}, "", valid_version_); + another_client_address, valid_version_, + kDefaultParsedChlo); // Advance clock to the time when connection 1 and 2 expires. clock_.AdvanceTime( @@ -320,9 +327,9 @@ TEST_F(QuicBufferedPacketStoreTest, PacketQueueExpiredBeforeDelivery) { // for them to expire. QuicConnectionId connection_id4 = TestConnectionId(4); store_.EnqueuePacket(connection_id4, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); store_.EnqueuePacket(connection_id4, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); clock_.AdvanceTime( QuicBufferedPacketStorePeer::expiration_alarm(&store_)->deadline() - clock_.ApproximateNow()); @@ -337,9 +344,9 @@ TEST_F(QuicBufferedPacketStoreTest, SimpleDiscardPackets) { // Enqueue some packets store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); EXPECT_FALSE(store_.HasChlosBuffered()); @@ -363,11 +370,11 @@ TEST_F(QuicBufferedPacketStoreTest, DiscardWithCHLOs) { // Enqueue some packets, which include a CHLO store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, true, {}, "", valid_version_); + peer_address_, valid_version_, kDefaultParsedChlo); store_.EnqueuePacket(connection_id, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); EXPECT_TRUE(store_.HasBufferedPackets(connection_id)); EXPECT_TRUE(store_.HasChlosBuffered()); @@ -392,12 +399,15 @@ TEST_F(QuicBufferedPacketStoreTest, MultipleDiscardPackets) { // Enqueue some packets for two connection IDs store_.EnqueuePacket(connection_id_1, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); store_.EnqueuePacket(connection_id_1, false, packet_, self_address_, - peer_address_, false, {}, "", invalid_version_); + peer_address_, invalid_version_, kNoParsedChlo); + + ParsedClientHello parsed_chlo; + parsed_chlo.alpns.push_back("h3"); + parsed_chlo.sni = TestHostname(); store_.EnqueuePacket(connection_id_2, false, packet_, self_address_, - peer_address_, true, {"h3"}, TestHostname(), - valid_version_); + peer_address_, valid_version_, parsed_chlo); EXPECT_TRUE(store_.HasBufferedPackets(connection_id_1)); EXPECT_TRUE(store_.HasBufferedPackets(connection_id_2)); EXPECT_TRUE(store_.HasChlosBuffered()); @@ -414,9 +424,9 @@ TEST_F(QuicBufferedPacketStoreTest, MultipleDiscardPackets) { EXPECT_TRUE(store_.HasBufferedPackets(connection_id_2)); auto packets = store_.DeliverPackets(connection_id_2); EXPECT_EQ(1u, packets.buffered_packets.size()); - ASSERT_EQ(1u, packets.alpns.size()); - EXPECT_EQ("h3", packets.alpns[0]); - EXPECT_EQ(TestHostname(), packets.sni); + ASSERT_EQ(1u, packets.parsed_chlo->alpns.size()); + EXPECT_EQ("h3", packets.parsed_chlo->alpns[0]); + EXPECT_EQ(TestHostname(), packets.parsed_chlo->sni); // Since connection_id_2's chlo arrives, verify version is set. EXPECT_EQ(valid_version_, packets.version); EXPECT_TRUE(store_.HasChlosBuffered()); diff --git a/gquiche/quic/core/quic_chaos_protector.cc b/gquiche/quic/core/quic_chaos_protector.cc new file mode 100644 index 00000000..a1d472ab --- /dev/null +++ b/gquiche/quic/core/quic_chaos_protector.cc @@ -0,0 +1,232 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/quic_chaos_protector.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "gquiche/quic/core/crypto/quic_random.h" +#include "gquiche/quic/core/frames/quic_crypto_frame.h" +#include "gquiche/quic/core/frames/quic_frame.h" +#include "gquiche/quic/core/frames/quic_padding_frame.h" +#include "gquiche/quic/core/frames/quic_ping_frame.h" +#include "gquiche/quic/core/quic_data_reader.h" +#include "gquiche/quic/core/quic_data_writer.h" +#include "gquiche/quic/core/quic_framer.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_stream_frame_data_producer.h" +#include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace quic { + +QuicChaosProtector::QuicChaosProtector(const QuicCryptoFrame& crypto_frame, + int num_padding_bytes, + size_t packet_size, + QuicFramer* framer, + QuicRandom* random) + : packet_size_(packet_size), + crypto_data_length_(crypto_frame.data_length), + crypto_buffer_offset_(crypto_frame.offset), + level_(crypto_frame.level), + remaining_padding_bytes_(num_padding_bytes), + framer_(framer), + random_(random) { + QUICHE_DCHECK_NE(framer_, nullptr); + QUICHE_DCHECK_NE(framer_->data_producer(), nullptr); + QUICHE_DCHECK_NE(random_, nullptr); +} + +QuicChaosProtector::~QuicChaosProtector() { + DeleteFrames(&frames_); +} + +absl::optional QuicChaosProtector::BuildDataPacket( + const QuicPacketHeader& header, + char* buffer) { + if (!CopyCryptoDataToLocalBuffer()) { + return absl::nullopt; + } + SplitCryptoFrame(); + AddPingFrames(); + SpreadPadding(); + ReorderFrames(); + return BuildPacket(header, buffer); +} + +WriteStreamDataResult QuicChaosProtector::WriteStreamData( + QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* /*writer*/) { + QUIC_BUG(chaos stream) << "This should never be called; id " << id + << " offset " << offset << " data_length " + << data_length; + return STREAM_MISSING; +} + +bool QuicChaosProtector::WriteCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) { + if (level != level_) { + QUIC_BUG(chaos bad level) << "Unexpected " << level << " != " << level_; + return false; + } + // This is `offset + data_length > buffer_offset_ + buffer_length_` + // but with integer overflow protection. + if (offset < crypto_buffer_offset_ || data_length > crypto_data_length_ || + offset - crypto_buffer_offset_ > crypto_data_length_ - data_length) { + QUIC_BUG(chaos bad lengths) + << "Unexpected buffer_offset_ " << crypto_buffer_offset_ << " offset " + << offset << " buffer_length_ " << crypto_data_length_ + << " data_length " << data_length; + return false; + } + writer->WriteBytes(&crypto_data_buffer_[offset - crypto_buffer_offset_], + data_length); + return true; +} + +bool QuicChaosProtector::CopyCryptoDataToLocalBuffer() { + crypto_frame_buffer_ = std::make_unique(packet_size_); + frames_.push_back(QuicFrame( + new QuicCryptoFrame(level_, crypto_buffer_offset_, crypto_data_length_))); + // We use |framer_| to serialize the CRYPTO frame in order to extract its + // data from the crypto data producer. This ensures that we reuse the + // usual serialization code path, but has the downside that we then need to + // parse the offset and length in order to skip over those fields. + QuicDataWriter writer(packet_size_, crypto_frame_buffer_.get()); + if (!framer_->AppendCryptoFrame(*frames_.front().crypto_frame, &writer)) { + QUIC_BUG(chaos write crypto data); + return false; + } + QuicDataReader reader(crypto_frame_buffer_.get(), writer.length()); + uint64_t parsed_offset, parsed_length; + if (!reader.ReadVarInt62(&parsed_offset) || + !reader.ReadVarInt62(&parsed_length)) { + QUIC_BUG(chaos parse crypto frame); + return false; + } + + absl::string_view crypto_data = reader.ReadRemainingPayload(); + crypto_data_buffer_ = crypto_data.data(); + + QUICHE_DCHECK_EQ(parsed_offset, crypto_buffer_offset_); + QUICHE_DCHECK_EQ(parsed_length, crypto_data_length_); + QUICHE_DCHECK_EQ(parsed_length, crypto_data.length()); + + return true; +} + +void QuicChaosProtector::SplitCryptoFrame() { + const int max_overhead_of_adding_a_crypto_frame = + static_cast(QuicFramer::GetMinCryptoFrameSize( + crypto_buffer_offset_ + crypto_data_length_, crypto_data_length_)); + // Pick a random number of CRYPTO frames to add. + constexpr uint64_t kMaxAddedCryptoFrames = 10; + const uint64_t num_added_crypto_frames = + random_->InsecureRandUint64() % (kMaxAddedCryptoFrames + 1); + for (uint64_t i = 0; i < num_added_crypto_frames; i++) { + if (remaining_padding_bytes_ < max_overhead_of_adding_a_crypto_frame) { + break; + } + // Pick a random frame and split it by shrinking the picked frame and + // moving the second half of its data to a new frame that is then appended + // to |frames|. + size_t frame_to_split_index = + random_->InsecureRandUint64() % frames_.size(); + QuicCryptoFrame* frame_to_split = + frames_[frame_to_split_index].crypto_frame; + if (frame_to_split->data_length <= 1) { + continue; + } + const int frame_to_split_old_overhead = + static_cast(QuicFramer::GetMinCryptoFrameSize( + frame_to_split->offset, frame_to_split->data_length)); + const QuicPacketLength frame_to_split_new_data_length = + 1 + (random_->InsecureRandUint64() % (frame_to_split->data_length - 1)); + const QuicPacketLength new_frame_data_length = + frame_to_split->data_length - frame_to_split_new_data_length; + const QuicStreamOffset new_frame_offset = + frame_to_split->offset + frame_to_split_new_data_length; + frame_to_split->data_length -= new_frame_data_length; + frames_.push_back(QuicFrame( + new QuicCryptoFrame(level_, new_frame_offset, new_frame_data_length))); + const int frame_to_split_new_overhead = + static_cast(QuicFramer::GetMinCryptoFrameSize( + frame_to_split->offset, frame_to_split->data_length)); + const int new_frame_overhead = + static_cast(QuicFramer::GetMinCryptoFrameSize( + new_frame_offset, new_frame_data_length)); + QUICHE_DCHECK_LE(frame_to_split_new_overhead, frame_to_split_old_overhead); + // Readjust padding based on increased overhead. + remaining_padding_bytes_ -= new_frame_overhead; + remaining_padding_bytes_ -= frame_to_split_new_overhead; + remaining_padding_bytes_ += frame_to_split_old_overhead; + } +} + +void QuicChaosProtector::AddPingFrames() { + if (remaining_padding_bytes_ == 0) { + return; + } + constexpr uint64_t kMaxAddedPingFrames = 10; + const uint64_t num_ping_frames = + random_->InsecureRandUint64() % + std::min(kMaxAddedPingFrames, remaining_padding_bytes_); + for (uint64_t i = 0; i < num_ping_frames; i++) { + frames_.push_back(QuicFrame(QuicPingFrame())); + } + remaining_padding_bytes_ -= static_cast(num_ping_frames); +} + +void QuicChaosProtector::ReorderFrames() { + // Walk the array backwards and swap each frame with a random earlier one. + for (size_t i = frames_.size() - 1; i > 0; i--) { + std::swap(frames_[i], frames_[random_->InsecureRandUint64() % (i + 1)]); + } +} + +void QuicChaosProtector::SpreadPadding() { + for (auto it = frames_.begin(); it != frames_.end(); ++it) { + const int padding_bytes_in_this_frame = + random_->InsecureRandUint64() % (remaining_padding_bytes_ + 1); + if (padding_bytes_in_this_frame <= 0) { + continue; + } + it = frames_.insert( + it, QuicFrame(QuicPaddingFrame(padding_bytes_in_this_frame))); + ++it; // Skip over the padding frame we just added. + remaining_padding_bytes_ -= padding_bytes_in_this_frame; + } + if (remaining_padding_bytes_ > 0) { + frames_.push_back(QuicFrame(QuicPaddingFrame(remaining_padding_bytes_))); + } +} + +absl::optional QuicChaosProtector::BuildPacket( + const QuicPacketHeader& header, + char* buffer) { + QuicStreamFrameDataProducer* original_data_producer = + framer_->data_producer(); + framer_->set_data_producer(this); + + size_t length = + framer_->BuildDataPacket(header, frames_, buffer, packet_size_, level_); + + framer_->set_data_producer(original_data_producer); + if (length == 0) { + return absl::nullopt; + } + return length; +} + +} // namespace quic diff --git a/gquiche/quic/core/quic_chaos_protector.h b/gquiche/quic/core/quic_chaos_protector.h new file mode 100644 index 00000000..d03d371d --- /dev/null +++ b/gquiche/quic/core/quic_chaos_protector.h @@ -0,0 +1,99 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CHAOS_PROTECTOR_H_ +#define QUICHE_QUIC_CORE_QUIC_CHAOS_PROTECTOR_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "gquiche/quic/core/crypto/quic_random.h" +#include "gquiche/quic/core/frames/quic_crypto_frame.h" +#include "gquiche/quic/core/frames/quic_frame.h" +#include "gquiche/quic/core/quic_data_writer.h" +#include "gquiche/quic/core/quic_framer.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_stream_frame_data_producer.h" +#include "gquiche/quic/core/quic_types.h" + +namespace quic { + +namespace test { +class QuicChaosProtectorTest; +} + +// QuicChaosProtector will take a crypto frame and an amount of padding and +// build a data packet that will parse to something equivalent. +class QUIC_EXPORT_PRIVATE QuicChaosProtector + : public QuicStreamFrameDataProducer { + public: + // |framer| and |random| must be valid for the lifetime of QuicChaosProtector. + explicit QuicChaosProtector(const QuicCryptoFrame& crypto_frame, + int num_padding_bytes, + size_t packet_size, + QuicFramer* framer, + QuicRandom* random); + + ~QuicChaosProtector() override; + + QuicChaosProtector(const QuicChaosProtector&) = delete; + QuicChaosProtector(QuicChaosProtector&&) = delete; + QuicChaosProtector& operator=(const QuicChaosProtector&) = delete; + QuicChaosProtector& operator=(QuicChaosProtector&&) = delete; + + // Attempts to build a data packet with chaos protection. If an error occurs, + // then absl::nullopt is returned. Otherwise returns the serialized length. + absl::optional BuildDataPacket(const QuicPacketHeader& header, + char* buffer); + + // From QuicStreamFrameDataProducer. + WriteStreamDataResult WriteStreamData(QuicStreamId id, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* /*writer*/) override; + bool WriteCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override; + + private: + friend class test::QuicChaosProtectorTest; + + // Allocate the crypto data buffer, create the CRYPTO frame and write the + // crypto data to our buffer. + bool CopyCryptoDataToLocalBuffer(); + + // Split the CRYPTO frame in |frames_| into one or more CRYPTO frames that + // collectively represent the same data. Adjusts padding to compensate. + void SplitCryptoFrame(); + + // Add a random number of PING frames to |frames_| and adjust padding. + void AddPingFrames(); + + // Randomly reorder |frames_|. + void ReorderFrames(); + + // Add PADDING frames randomly between all other frames. + void SpreadPadding(); + + // Serialize |frames_| using |framer_|. + absl::optional BuildPacket(const QuicPacketHeader& header, + char* buffer); + + size_t packet_size_; + std::unique_ptr crypto_frame_buffer_; + const char* crypto_data_buffer_ = nullptr; + QuicByteCount crypto_data_length_; + QuicStreamOffset crypto_buffer_offset_; + EncryptionLevel level_; + int remaining_padding_bytes_; + QuicFrames frames_; // Inner frames owned, will be deleted by destructor. + QuicFramer* framer_; // Unowned. + QuicRandom* random_; // Unowned. +}; + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CHAOS_PROTECTOR_H_ diff --git a/gquiche/quic/core/quic_chaos_protector_test.cc b/gquiche/quic/core/quic_chaos_protector_test.cc new file mode 100644 index 00000000..3b0570eb --- /dev/null +++ b/gquiche/quic/core/quic_chaos_protector_test.cc @@ -0,0 +1,230 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/quic_chaos_protector.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "gquiche/quic/core/frames/quic_crypto_frame.h" +#include "gquiche/quic/core/quic_connection_id.h" +#include "gquiche/quic/core/quic_framer.h" +#include "gquiche/quic/core/quic_packet_number.h" +#include "gquiche/quic/core/quic_packets.h" +#include "gquiche/quic/core/quic_stream_frame_data_producer.h" +#include "gquiche/quic/core/quic_time.h" +#include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/core/quic_versions.h" +#include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/quic/test_tools/mock_random.h" +#include "gquiche/quic/test_tools/quic_test_utils.h" +#include "gquiche/quic/test_tools/simple_quic_framer.h" + +namespace quic { +namespace test { + +class QuicChaosProtectorTest : public QuicTestWithParam, + public QuicStreamFrameDataProducer { + public: + QuicChaosProtectorTest() + : version_(GetParam()), + framer_({version_}, QuicTime::Zero(), Perspective::IS_CLIENT, + kQuicDefaultConnectionIdLength), + validation_framer_({version_}), + random_(/*base=*/3), + level_(ENCRYPTION_INITIAL), + crypto_offset_(0), + crypto_data_length_(100), + crypto_frame_(level_, crypto_offset_, crypto_data_length_), + num_padding_bytes_(50), + packet_size_(1000), + packet_buffer_(std::make_unique(packet_size_)) { + ReCreateChaosProtector(); + } + + void ReCreateChaosProtector() { + chaos_protector_ = std::make_unique( + crypto_frame_, num_padding_bytes_, packet_size_, + SetupHeaderAndFramers(), &random_); + } + + // From QuicStreamFrameDataProducer. + WriteStreamDataResult WriteStreamData(QuicStreamId /*id*/, + QuicStreamOffset /*offset*/, + QuicByteCount /*data_length*/, + QuicDataWriter* /*writer*/) override { + ADD_FAILURE() << "This should never be called"; + return STREAM_MISSING; + } + + // From QuicStreamFrameDataProducer. + bool WriteCryptoData(EncryptionLevel level, + QuicStreamOffset offset, + QuicByteCount data_length, + QuicDataWriter* writer) override { + EXPECT_EQ(level, level); + EXPECT_EQ(offset, crypto_offset_); + EXPECT_EQ(data_length, crypto_data_length_); + for (QuicByteCount i = 0; i < data_length; i++) { + EXPECT_TRUE(writer->WriteUInt8(static_cast(i & 0xFF))); + } + return true; + } + + protected: + QuicFramer* SetupHeaderAndFramers() { + // Setup header. + header_.destination_connection_id = TestConnectionId(); + header_.destination_connection_id_included = CONNECTION_ID_PRESENT; + header_.source_connection_id = EmptyQuicConnectionId(); + header_.source_connection_id_included = CONNECTION_ID_PRESENT; + header_.reset_flag = false; + header_.version_flag = true; + header_.has_possible_stateless_reset_token = false; + header_.packet_number_length = PACKET_4BYTE_PACKET_NUMBER; + header_.version = version_; + header_.packet_number = QuicPacketNumber(1); + header_.form = IETF_QUIC_LONG_HEADER_PACKET; + header_.long_packet_type = INITIAL; + header_.retry_token_length_length = VARIABLE_LENGTH_INTEGER_LENGTH_1; + header_.length_length = kQuicDefaultLongHeaderLengthLength; + // Setup validation framer. + validation_framer_.framer()->SetInitialObfuscators( + header_.destination_connection_id); + // Setup framer. + framer_.SetInitialObfuscators(header_.destination_connection_id); + framer_.set_data_producer(this); + return &framer_; + } + + void BuildEncryptAndParse() { + absl::optional length = + chaos_protector_->BuildDataPacket(header_, packet_buffer_.get()); + ASSERT_TRUE(length.has_value()); + ASSERT_GT(length.value(), 0u); + size_t encrypted_length = framer_.EncryptInPlace( + level_, header_.packet_number, + GetStartOfEncryptedData(framer_.transport_version(), header_), + length.value(), packet_size_, packet_buffer_.get()); + ASSERT_GT(encrypted_length, 0u); + ASSERT_TRUE(validation_framer_.ProcessPacket(QuicEncryptedPacket( + absl::string_view(packet_buffer_.get(), encrypted_length)))); + } + + void ResetOffset(QuicStreamOffset offset) { + crypto_offset_ = offset; + crypto_frame_.offset = offset; + ReCreateChaosProtector(); + } + + void ResetLength(QuicByteCount length) { + crypto_data_length_ = length; + crypto_frame_.data_length = length; + ReCreateChaosProtector(); + } + + ParsedQuicVersion version_; + QuicPacketHeader header_; + QuicFramer framer_; + SimpleQuicFramer validation_framer_; + MockRandom random_; + EncryptionLevel level_; + QuicStreamOffset crypto_offset_; + QuicByteCount crypto_data_length_; + QuicCryptoFrame crypto_frame_; + int num_padding_bytes_; + size_t packet_size_; + std::unique_ptr packet_buffer_; + std::unique_ptr chaos_protector_; +}; + +namespace { + +ParsedQuicVersionVector TestVersions() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : AllSupportedVersions()) { + if (version.UsesCryptoFrames()) { + versions.push_back(version); + } + } + return versions; +} + +INSTANTIATE_TEST_SUITE_P(QuicChaosProtectorTests, + QuicChaosProtectorTest, + ::testing::ValuesIn(TestVersions()), + ::testing::PrintToStringParamName()); + +TEST_P(QuicChaosProtectorTest, Main) { + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 4u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, 0u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, 1u); + ASSERT_EQ(validation_framer_.ping_frames().size(), 3u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 7u); + EXPECT_EQ(validation_framer_.padding_frames()[0].num_padding_bytes, 3); +} + +TEST_P(QuicChaosProtectorTest, DifferentRandom) { + random_.ResetBase(4); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 4u); + ASSERT_EQ(validation_framer_.ping_frames().size(), 4u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 8u); +} + +TEST_P(QuicChaosProtectorTest, RandomnessZero) { + random_.ResetBase(0); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 1u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, + crypto_data_length_); + ASSERT_EQ(validation_framer_.ping_frames().size(), 0u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 1u); +} + +TEST_P(QuicChaosProtectorTest, Offset) { + ResetOffset(123); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 4u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, 1u); + ASSERT_EQ(validation_framer_.ping_frames().size(), 3u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 7u); + EXPECT_EQ(validation_framer_.padding_frames()[0].num_padding_bytes, 3); +} + +TEST_P(QuicChaosProtectorTest, OffsetAndRandomnessZero) { + ResetOffset(123); + random_.ResetBase(0); + BuildEncryptAndParse(); + ASSERT_EQ(validation_framer_.crypto_frames().size(), 1u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, + crypto_data_length_); + ASSERT_EQ(validation_framer_.ping_frames().size(), 0u); + ASSERT_EQ(validation_framer_.padding_frames().size(), 1u); +} + +TEST_P(QuicChaosProtectorTest, ZeroRemainingBytesAfterSplit) { + QuicPacketLength new_length = 63; + num_padding_bytes_ = QuicFramer::GetMinCryptoFrameSize( + crypto_frame_.offset + new_length, new_length); + ResetLength(new_length); + BuildEncryptAndParse(); + + ASSERT_EQ(validation_framer_.crypto_frames().size(), 2u); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->offset, crypto_offset_); + EXPECT_EQ(validation_framer_.crypto_frames()[0]->data_length, 4); + EXPECT_EQ(validation_framer_.crypto_frames()[1]->offset, crypto_offset_ + 4); + EXPECT_EQ(validation_framer_.crypto_frames()[1]->data_length, + crypto_data_length_ - 4); + ASSERT_EQ(validation_framer_.ping_frames().size(), 0u); +} + +} // namespace +} // namespace test +} // namespace quic diff --git a/gquiche/quic/core/quic_connection.cc b/gquiche/quic/core/quic_connection.cc index 54fdaed2..7c931f34 100644 --- a/gquiche/quic/core/quic_connection.cc +++ b/gquiche/quic/core/quic_connection.cc @@ -45,10 +45,10 @@ #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_hostname_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_server_stats.h" #include "gquiche/quic/platform/api/quic_socket_address.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/platform/api/quiche_flag_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -63,13 +63,27 @@ const QuicPacketCount kMaxConsecutiveNonRetransmittablePackets = 19; // The minimum release time into future in ms. const int kMinReleaseTimeIntoFutureMs = 1; -// An alarm that is scheduled to send an ack if a timeout occurs. -class AckAlarmDelegate : public QuicAlarm::Delegate { +// Base class of all alarms owned by a QuicConnection. +class QuicConnectionAlarmDelegate : public QuicAlarm::Delegate { public: - explicit AckAlarmDelegate(QuicConnection* connection) + explicit QuicConnectionAlarmDelegate(QuicConnection* connection) : connection_(connection) {} - AckAlarmDelegate(const AckAlarmDelegate&) = delete; - AckAlarmDelegate& operator=(const AckAlarmDelegate&) = delete; + QuicConnectionAlarmDelegate(const QuicConnectionAlarmDelegate&) = delete; + QuicConnectionAlarmDelegate& operator=(const QuicConnectionAlarmDelegate&) = + delete; + + QuicConnectionContext* GetConnectionContext() override { + return (connection_ == nullptr) ? nullptr : connection_->context(); + } + + protected: + QuicConnection* connection_; +}; + +// An alarm that is scheduled to send an ack if a timeout occurs. +class AckAlarmDelegate : public QuicConnectionAlarmDelegate { + public: + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->ack_frame_updated()); @@ -81,136 +95,86 @@ class AckAlarmDelegate : public QuicAlarm::Delegate { connection_->SendAck(); } } - - private: - QuicConnection* connection_; }; // This alarm will be scheduled any time a data-bearing packet is sent out. // When the alarm goes off, the connection checks to see if the oldest packets // have been acked, and retransmit them if they have not. -class RetransmissionAlarmDelegate : public QuicAlarm::Delegate { +class RetransmissionAlarmDelegate : public QuicConnectionAlarmDelegate { public: - explicit RetransmissionAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - RetransmissionAlarmDelegate(const RetransmissionAlarmDelegate&) = delete; - RetransmissionAlarmDelegate& operator=(const RetransmissionAlarmDelegate&) = - delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); connection_->OnRetransmissionTimeout(); } - - private: - QuicConnection* connection_; }; // An alarm that is scheduled when the SentPacketManager requires a delay // before sending packets and fires when the packet may be sent. -class SendAlarmDelegate : public QuicAlarm::Delegate { +class SendAlarmDelegate : public QuicConnectionAlarmDelegate { public: - explicit SendAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - SendAlarmDelegate(const SendAlarmDelegate&) = delete; - SendAlarmDelegate& operator=(const SendAlarmDelegate&) = delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); connection_->WriteIfNotBlocked(); } - - private: - QuicConnection* connection_; }; -class PingAlarmDelegate : public QuicAlarm::Delegate { +class PingAlarmDelegate : public QuicConnectionAlarmDelegate { public: - explicit PingAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - PingAlarmDelegate(const PingAlarmDelegate&) = delete; - PingAlarmDelegate& operator=(const PingAlarmDelegate&) = delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); connection_->OnPingTimeout(); } - - private: - QuicConnection* connection_; }; -class MtuDiscoveryAlarmDelegate : public QuicAlarm::Delegate { +class MtuDiscoveryAlarmDelegate : public QuicConnectionAlarmDelegate { public: - explicit MtuDiscoveryAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - MtuDiscoveryAlarmDelegate(const MtuDiscoveryAlarmDelegate&) = delete; - MtuDiscoveryAlarmDelegate& operator=(const MtuDiscoveryAlarmDelegate&) = - delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); connection_->DiscoverMtu(); } - - private: - QuicConnection* connection_; }; -class ProcessUndecryptablePacketsAlarmDelegate : public QuicAlarm::Delegate { +class ProcessUndecryptablePacketsAlarmDelegate + : public QuicConnectionAlarmDelegate { public: - explicit ProcessUndecryptablePacketsAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - ProcessUndecryptablePacketsAlarmDelegate( - const ProcessUndecryptablePacketsAlarmDelegate&) = delete; - ProcessUndecryptablePacketsAlarmDelegate& operator=( - const ProcessUndecryptablePacketsAlarmDelegate&) = delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); QuicConnection::ScopedPacketFlusher flusher(connection_); connection_->MaybeProcessUndecryptablePackets(); } - - private: - QuicConnection* connection_; }; -class DiscardPreviousOneRttKeysAlarmDelegate : public QuicAlarm::Delegate { +class DiscardPreviousOneRttKeysAlarmDelegate + : public QuicConnectionAlarmDelegate { public: - explicit DiscardPreviousOneRttKeysAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - DiscardPreviousOneRttKeysAlarmDelegate( - const DiscardPreviousOneRttKeysAlarmDelegate&) = delete; - DiscardPreviousOneRttKeysAlarmDelegate& operator=( - const DiscardPreviousOneRttKeysAlarmDelegate&) = delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); connection_->DiscardPreviousOneRttKeys(); } - - private: - QuicConnection* connection_; }; -class DiscardZeroRttDecryptionKeysAlarmDelegate : public QuicAlarm::Delegate { +class DiscardZeroRttDecryptionKeysAlarmDelegate + : public QuicConnectionAlarmDelegate { public: - explicit DiscardZeroRttDecryptionKeysAlarmDelegate(QuicConnection* connection) - : connection_(connection) {} - DiscardZeroRttDecryptionKeysAlarmDelegate( - const DiscardZeroRttDecryptionKeysAlarmDelegate&) = delete; - DiscardZeroRttDecryptionKeysAlarmDelegate& operator=( - const DiscardZeroRttDecryptionKeysAlarmDelegate&) = delete; + using QuicConnectionAlarmDelegate::QuicConnectionAlarmDelegate; void OnAlarm() override { QUICHE_DCHECK(connection_->connected()); QUIC_DLOG(INFO) << "0-RTT discard alarm fired"; connection_->RemoveDecrypter(ENCRYPTION_ZERO_RTT); } - - private: - QuicConnection* connection_; }; // When the clearer goes out of scope, the coalesced packet gets cleared. @@ -225,8 +189,8 @@ class ScopedCoalescedPacketClearer { }; // Whether this incoming packet is allowed to replace our connection ID. -bool PacketCanReplaceConnectionId(const QuicPacketHeader& header, - Perspective perspective) { +bool PacketCanReplaceServerConnectionId(const QuicPacketHeader& header, + Perspective perspective) { return perspective == Perspective::IS_CLIENT && header.form == IETF_QUIC_LONG_HEADER_PACKET && header.version.IsKnown() && @@ -256,16 +220,11 @@ QuicConnection::QuicConnection( QuicConnectionId server_connection_id, QuicSocketAddress initial_self_address, QuicSocketAddress initial_peer_address, - QuicConnectionHelperInterface* helper, - QuicAlarmFactory* alarm_factory, - QuicPacketWriter* writer, - bool owns_writer, - Perspective perspective, + QuicConnectionHelperInterface* helper, QuicAlarmFactory* alarm_factory, + QuicPacketWriter* writer, bool owns_writer, Perspective perspective, const ParsedQuicVersionVector& supported_versions) - : framer_(supported_versions, - helper->GetClock()->ApproximateNow(), - perspective, - server_connection_id.length()), + : framer_(supported_versions, helper->GetClock()->ApproximateNow(), + perspective, server_connection_id.length()), current_packet_content_(NO_FRAMES_RECEIVED), is_current_packet_connectivity_probing_(false), has_path_challenge_in_current_packet_(false), @@ -278,16 +237,12 @@ QuicConnection::QuicConnection( encryption_level_(ENCRYPTION_INITIAL), clock_(helper->GetClock()), random_generator_(helper->GetRandomGenerator()), - server_connection_id_(server_connection_id), - client_connection_id_(EmptyQuicConnectionId()), client_connection_id_is_set_(false), direct_peer_address_(initial_peer_address), - default_path_(initial_self_address, - QuicSocketAddress(), + default_path_(initial_self_address, QuicSocketAddress(), /*client_connection_id=*/EmptyQuicConnectionId(), server_connection_id, - /*stateless_reset_token_received=*/false, - /*stateless_reset_token=*/{}), + /*stateless_reset_token=*/absl::nullopt), active_effective_peer_migration_type_(NO_CHANGE), support_key_update_for_connection_(false), last_packet_decrypted_(false), @@ -312,34 +267,25 @@ QuicConnection::QuicConnection( ack_alarm_(alarm_factory_->CreateAlarm(arena_.New(this), &arena_)), retransmission_alarm_(alarm_factory_->CreateAlarm( - arena_.New(this), - &arena_)), - send_alarm_( - alarm_factory_->CreateAlarm(arena_.New(this), - &arena_)), - ping_alarm_( - alarm_factory_->CreateAlarm(arena_.New(this), - &arena_)), + arena_.New(this), &arena_)), + send_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), + ping_alarm_(alarm_factory_->CreateAlarm( + arena_.New(this), &arena_)), mtu_discovery_alarm_(alarm_factory_->CreateAlarm( - arena_.New(this), - &arena_)), + arena_.New(this), &arena_)), process_undecryptable_packets_alarm_(alarm_factory_->CreateAlarm( - arena_.New(this), - &arena_)), + arena_.New(this), &arena_)), discard_previous_one_rtt_keys_alarm_(alarm_factory_->CreateAlarm( - arena_.New(this), - &arena_)), + arena_.New(this), &arena_)), discard_zero_rtt_decryption_keys_alarm_(alarm_factory_->CreateAlarm( arena_.New(this), &arena_)), visitor_(nullptr), debug_visitor_(nullptr), packet_creator_(server_connection_id, &framer_, random_generator_, this), - time_of_last_received_packet_(clock_->ApproximateNow()), - sent_packet_manager_(perspective, - clock_, - random_generator_, - &stats_, + last_received_packet_info_(clock_->ApproximateNow()), + sent_packet_manager_(perspective, clock_, random_generator_, &stats_, GetDefaultCongestionControlType()), version_negotiated_(false), perspective_(perspective), @@ -357,44 +303,20 @@ QuicConnection::QuicConnection( bundle_retransmittable_with_pto_ack_(false), fill_up_link_during_probing_(false), probing_retransmission_pending_(false), - stateless_reset_token_received_(false), - received_stateless_reset_token_({}), last_control_frame_id_(kInvalidControlFrameId), is_path_degrading_(false), processing_ack_frame_(false), supports_release_time_(false), release_time_into_future_(QuicTime::Delta::Zero()), - blackhole_detector_(this, &arena_, alarm_factory_), - idle_network_detector_(this, - clock_->ApproximateNow(), - &arena_, - alarm_factory_), - encrypted_control_frames_( - GetQuicReloadableFlag(quic_encrypted_control_frames)), - use_encryption_level_context_( - encrypted_control_frames_ && - GetQuicReloadableFlag(quic_use_encryption_level_context)), - path_validator_(alarm_factory_, &arena_, this, random_generator_), + blackhole_detector_(this, &arena_, alarm_factory_, &context_), + idle_network_detector_(this, clock_->ApproximateNow(), &arena_, + alarm_factory_, &context_), + path_validator_(alarm_factory_, &arena_, this, random_generator_, + &context_), most_recent_frame_type_(NUM_FRAME_TYPES) { - QUIC_BUG_IF(quic_bug_12714_1, - !start_peer_migration_earlier_ && send_path_response_); - QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT || default_path_.self_address.IsInitialized()); - if (use_encryption_level_context_) { - QUIC_RELOADABLE_FLAG_COUNT(quic_use_encryption_level_context); - } - - support_multiple_connection_ids_ = - version().HasIetfQuicFrames() && - framer_.do_not_synthesize_source_cid_for_short_header() && - GetQuicRestartFlag(quic_use_reference_counted_sesssion_map) && - GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid_v2) && - GetQuicRestartFlag( - quic_dispatcher_support_multiple_cid_per_connection_v2) && - GetQuicReloadableFlag(quic_connection_support_multiple_cids_v4); - QUIC_DLOG(INFO) << ENDPOINT << "Created connection with server connection ID " << server_connection_id << " and version: " << ParsedQuicVersionToString(version()); @@ -422,7 +344,7 @@ QuicConnection::QuicConnection( MaybeEnableMultiplePacketNumberSpacesSupport(); QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT || supported_versions.size() == 1); - InstallInitialCrypters(ServerConnectionId()); + InstallInitialCrypters(default_path_.server_connection_id); // On the server side, version negotiation has been done by the dispatcher, // and the server connection is created with the right version. @@ -453,6 +375,7 @@ void QuicConnection::InstallInitialCrypters(QuicConnectionId connection_id) { } QuicConnection::~QuicConnection() { + QUICHE_DCHECK_GE(stats_.max_egress_mtu, long_term_mtu_); if (owns_writer_) { delete writer_; } @@ -486,9 +409,9 @@ bool QuicConnection::ValidateConfigConnectionIds(const QuicConfig& config) { // Validate initial_source_connection_id. QuicConnectionId expected_initial_source_connection_id; if (perspective_ == Perspective::IS_CLIENT) { - expected_initial_source_connection_id = ServerConnectionId(); + expected_initial_source_connection_id = default_path_.server_connection_id; } else { - expected_initial_source_connection_id = ClientConnectionId(); + expected_initial_source_connection_id = default_path_.client_connection_id; } if (!config.HasReceivedInitialSourceConnectionId() || config.ReceivedInitialSourceConnectionId() != @@ -589,9 +512,12 @@ void QuicConnection::SetFromConfig(const QuicConfig& config) { } else { SetNetworkTimeouts(config.max_time_before_crypto_handshake(), config.max_idle_time_before_crypto_handshake()); + if (config.HasClientRequestedIndependentOption(kNCHP, perspective_)) { + packet_creator_.set_chaos_protection_enabled(false); + } } - if (support_multiple_connection_ids_ && + if (version().HasIetfQuicFrames() && config.HasReceivedPreferredAddressConnectionIdAndToken()) { QuicNewConnectionIdFrame frame; std::tie(frame.connection_id, frame.stateless_reset_token) = @@ -683,37 +609,40 @@ void QuicConnection::SetFromConfig(const QuicConfig& config) { no_stop_waiting_frames_ = true; } if (config.HasReceivedStatelessResetToken()) { - if (use_connection_id_on_default_path_) { - default_path_.stateless_reset_token_received = true; - default_path_.stateless_reset_token = - config.ReceivedStatelessResetToken(); - } else { - stateless_reset_token_received_ = true; - received_stateless_reset_token_ = config.ReceivedStatelessResetToken(); - } + default_path_.stateless_reset_token = config.ReceivedStatelessResetToken(); } if (config.HasReceivedAckDelayExponent()) { framer_.set_peer_ack_delay_exponent(config.ReceivedAckDelayExponent()); } - if (GetQuicReloadableFlag(quic_send_timestamps) && - config.HasClientSentConnectionOption(kSTMP, perspective_)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_send_timestamps); - framer_.set_process_timestamps(true); - uber_received_packet_manager_.set_save_timestamps(true); - } if (config.HasClientSentConnectionOption(kEACK, perspective_)) { bundle_retransmittable_with_pto_ack_ = true; } if (config.HasClientSentConnectionOption(kDFER, perspective_)) { defer_send_in_response_to_packets_ = false; } + const bool remove_connection_migration_connection_option = + GetQuicReloadableFlag(quic_remove_connection_migration_connection_option); + if (remove_connection_migration_connection_option) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_remove_connection_migration_connection_option); + } if (framer_.version().HasIetfQuicFrames() && use_path_validator_ && count_bytes_on_alternative_path_separately_ && GetQuicReloadableFlag(quic_server_reverse_validate_new_path3) && - config.HasClientSentConnectionOption(kRVCM, perspective_)) { + (remove_connection_migration_connection_option || + config.HasClientSentConnectionOption(kRVCM, perspective_))) { QUIC_CODE_COUNT_N(quic_server_reverse_validate_new_path3, 6, 6); validate_client_addresses_ = true; } + // Having connection_migration_use_new_cid_ depends on the same set of flags + // and connection option on both client and server sides has the advantage of: + // 1) Less chance of skew in using new connection ID or not between client + // and server in unit tests with random flag combinations. + // 2) Client side's rollout can be protected by the same connection option. + connection_migration_use_new_cid_ = + validate_client_addresses_ && + GetQuicReloadableFlag(quic_drop_unsent_path_response) && + GetQuicReloadableFlag(quic_connection_migration_use_new_cid_v2); if (config.HasReceivedMaxPacketSize()) { peer_max_packet_size_ = config.ReceivedMaxPacketSize(); MaybeUpdatePacketCreatorMaxPacketLengthAndPadding(); @@ -848,7 +777,8 @@ bool QuicConnection::SelectMutualVersion( framer_.supported_versions(); for (size_t i = 0; i < supported_versions.size(); ++i) { const ParsedQuicVersion& version = supported_versions[i]; - if (QuicContainsValue(available_versions, version)) { + if (std::find(available_versions.begin(), available_versions.end(), + version) != available_versions.end()) { framer_.set_version(version); return true; } @@ -875,7 +805,7 @@ void QuicConnection::OnPublicResetPacket(const QuicPublicResetPacket& packet) { // Check that any public reset packet with a different connection ID that was // routed to this QuicConnection has been redirected before control reaches // here. (Check for a bug regression.) - QUICHE_DCHECK_EQ(ServerConnectionId(), packet.connection_id); + QUICHE_DCHECK_EQ(default_path_.server_connection_id, packet.connection_id); QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); QUICHE_DCHECK(!version().HasIetfInvariantHeader()); if (debug_visitor_ != nullptr) { @@ -913,7 +843,7 @@ void QuicConnection::OnVersionNegotiationPacket( // Check that any public reset packet with a different connection ID that was // routed to this QuicConnection has been redirected before control reaches // here. (Check for a bug regression.) - QUICHE_DCHECK_EQ(ServerConnectionId(), packet.connection_id); + QUICHE_DCHECK_EQ(default_path_.server_connection_id, packet.connection_id); if (perspective_ == Perspective::IS_SERVER) { const std::string error_details = "Server received version negotiation packet."; @@ -932,7 +862,8 @@ void QuicConnection::OnVersionNegotiationPacket( return; } - if (QuicContainsValue(packet.versions, version())) { + if (std::find(packet.versions.begin(), packet.versions.end(), version()) != + packet.versions.end()) { const std::string error_details = absl::StrCat( "Server already supports client's version ", ParsedQuicVersionToString(version()), @@ -964,17 +895,17 @@ void QuicConnection::OnRetryPacket(QuicConnectionId original_connection_id, absl::string_view retry_without_tag) { QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); if (version().UsesTls()) { - if (!CryptoUtils::ValidateRetryIntegrityTag(version(), ServerConnectionId(), - retry_without_tag, - retry_integrity_tag)) { + if (!CryptoUtils::ValidateRetryIntegrityTag( + version(), default_path_.server_connection_id, retry_without_tag, + retry_integrity_tag)) { QUIC_DLOG(ERROR) << "Ignoring RETRY with invalid integrity tag"; return; } } else { - if (original_connection_id != ServerConnectionId()) { + if (original_connection_id != default_path_.server_connection_id) { QUIC_DLOG(ERROR) << "Ignoring RETRY with original connection ID " << original_connection_id << " not matching expected " - << ServerConnectionId() << " token " + << default_path_.server_connection_id << " token " << absl::BytesToHexString(retry_token); return; } @@ -982,10 +913,11 @@ void QuicConnection::OnRetryPacket(QuicConnectionId original_connection_id, framer_.set_drop_incoming_retry_packets(true); stats_.retry_packet_processed = true; QUIC_DLOG(INFO) << "Received RETRY, replacing connection ID " - << ServerConnectionId() << " with " << new_connection_id - << ", received token " << absl::BytesToHexString(retry_token); + << default_path_.server_connection_id << " with " + << new_connection_id << ", received token " + << absl::BytesToHexString(retry_token); if (!original_destination_connection_id_.has_value()) { - original_destination_connection_id_ = ServerConnectionId(); + original_destination_connection_id_ = default_path_.server_connection_id; } QUICHE_DCHECK(!retry_source_connection_id_.has_value()) << retry_source_connection_id_.value(); @@ -994,83 +926,93 @@ void QuicConnection::OnRetryPacket(QuicConnectionId original_connection_id, packet_creator_.SetRetryToken(retry_token); // Reinstall initial crypters because the connection ID changed. - InstallInitialCrypters(ServerConnectionId()); + InstallInitialCrypters(default_path_.server_connection_id); sent_packet_manager_.MarkInitialPacketsForRetransmission(); } -bool QuicConnection::HasIncomingConnectionId(QuicConnectionId connection_id) { - if (quic_deprecate_incoming_connection_ids_) { - QUIC_RELOADABLE_FLAG_COUNT(quic_deprecate_incoming_connection_ids); - // TODO(haoyuewang) Inline this after the flag is deprecated. - return connection_id == original_destination_connection_id_; - } - for (QuicConnectionId const& incoming_connection_id : - incoming_connection_ids_) { - if (incoming_connection_id == connection_id) { - return true; - } - } - return false; -} - void QuicConnection::SetOriginalDestinationConnectionId( const QuicConnectionId& original_destination_connection_id) { QUIC_DLOG(INFO) << "Setting original_destination_connection_id to " << original_destination_connection_id << " on connection with server_connection_id " - << ServerConnectionId(); - QUICHE_DCHECK_NE(original_destination_connection_id, ServerConnectionId()); - if (!quic_deprecate_incoming_connection_ids_) { - if (!HasIncomingConnectionId(original_destination_connection_id)) { - incoming_connection_ids_.push_back(original_destination_connection_id); - } - } + << default_path_.server_connection_id; + QUICHE_DCHECK_NE(original_destination_connection_id, + default_path_.server_connection_id); InstallInitialCrypters(original_destination_connection_id); QUICHE_DCHECK(!original_destination_connection_id_.has_value()) << original_destination_connection_id_.value(); original_destination_connection_id_ = original_destination_connection_id; + original_destination_connection_id_replacement_ = + default_path_.server_connection_id; } QuicConnectionId QuicConnection::GetOriginalDestinationConnectionId() { if (original_destination_connection_id_.has_value()) { return original_destination_connection_id_.value(); } - return ServerConnectionId(); + return default_path_.server_connection_id; +} + +bool QuicConnection::ValidateServerConnectionId( + const QuicPacketHeader& header) const { + if (perspective_ == Perspective::IS_CLIENT && + header.form == IETF_QUIC_SHORT_HEADER_PACKET) { + return true; + } + + QuicConnectionId server_connection_id = + GetServerConnectionIdAsRecipient(header, perspective_); + + if (server_connection_id == default_path_.server_connection_id || + server_connection_id == original_destination_connection_id_) { + return true; + } + + if (PacketCanReplaceServerConnectionId(header, perspective_)) { + QUIC_DLOG(INFO) << ENDPOINT << "Accepting packet with new connection ID " + << server_connection_id << " instead of " + << default_path_.server_connection_id; + return true; + } + + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_SERVER && + self_issued_cid_manager_ != nullptr && + self_issued_cid_manager_->IsConnectionIdInUse(server_connection_id)) { + return true; + } + + return false; } bool QuicConnection::OnUnauthenticatedPublicHeader( const QuicPacketHeader& header) { last_packet_destination_connection_id_ = header.destination_connection_id; + // If last packet destination connection ID is the original server + // connection ID chosen by client, replaces it with the connection ID chosen + // by server. + if (perspective_ == Perspective::IS_SERVER && + original_destination_connection_id_.has_value() && + last_packet_destination_connection_id_ == + *original_destination_connection_id_) { + last_packet_destination_connection_id_ = + original_destination_connection_id_replacement_; + } // As soon as we receive an initial we start ignoring subsequent retries. if (header.version_flag && header.long_packet_type == INITIAL) { framer_.set_drop_incoming_retry_packets(true); } - bool skip_server_connection_id_validation = - framer_.do_not_synthesize_source_cid_for_short_header() && - perspective_ == Perspective::IS_CLIENT && - header.form == IETF_QUIC_SHORT_HEADER_PACKET; - - QuicConnectionId server_connection_id = - GetServerConnectionIdAsRecipient(header, perspective_); - - if (!skip_server_connection_id_validation && - server_connection_id != ServerConnectionId() && - !HasIncomingConnectionId(server_connection_id)) { - if (PacketCanReplaceConnectionId(header, perspective_)) { - QUIC_DLOG(INFO) << ENDPOINT << "Accepting packet with new connection ID " - << server_connection_id << " instead of " - << ServerConnectionId(); - return true; - } - + if (!ValidateServerConnectionId(header)) { ++stats_.packets_dropped; + QuicConnectionId server_connection_id = + GetServerConnectionIdAsRecipient(header, perspective_); QUIC_DLOG(INFO) << ENDPOINT << "Ignoring packet from unexpected server connection ID " << server_connection_id << " instead of " - << ServerConnectionId(); + << default_path_.server_connection_id; if (debug_visitor_ != nullptr) { debug_visitor_->OnIncorrectConnectionId(server_connection_id); } @@ -1085,18 +1027,15 @@ bool QuicConnection::OnUnauthenticatedPublicHeader( return true; } - if (framer_.do_not_synthesize_source_cid_for_short_header() && - perspective_ == Perspective::IS_SERVER && + if (perspective_ == Perspective::IS_SERVER && header.form == IETF_QUIC_SHORT_HEADER_PACKET) { - QUIC_RELOADABLE_FLAG_COUNT_N( - quic_do_not_synthesize_source_cid_for_short_header, 3, 3); return true; } QuicConnectionId client_connection_id = GetClientConnectionIdAsRecipient(header, perspective_); - if (client_connection_id == ClientConnectionId()) { + if (client_connection_id == default_path_.client_connection_id) { return true; } @@ -1108,11 +1047,18 @@ bool QuicConnection::OnUnauthenticatedPublicHeader( return true; } + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_CLIENT && + self_issued_cid_manager_ != nullptr && + self_issued_cid_manager_->IsConnectionIdInUse(client_connection_id)) { + return true; + } + ++stats_.packets_dropped; QUIC_DLOG(INFO) << ENDPOINT << "Ignoring packet from unexpected client connection ID " << client_connection_id << " instead of " - << ClientConnectionId(); + << default_path_.client_connection_id; return false; } @@ -1121,17 +1067,8 @@ bool QuicConnection::OnUnauthenticatedHeader(const QuicPacketHeader& header) { debug_visitor_->OnUnauthenticatedHeader(header); } - // Check that any public reset packet with a different connection ID that was - // routed to this QuicConnection has been redirected before control reaches - // here. - QUICHE_DCHECK((framer_.do_not_synthesize_source_cid_for_short_header() && - perspective_ == Perspective::IS_CLIENT && - header.form == IETF_QUIC_SHORT_HEADER_PACKET) || - GetServerConnectionIdAsRecipient(header, perspective_) == - ServerConnectionId() || - HasIncomingConnectionId( - GetServerConnectionIdAsRecipient(header, perspective_)) || - PacketCanReplaceConnectionId(header, perspective_)); + // Sanity check on the server connection ID in header. + QUICHE_DCHECK(ValidateServerConnectionId(header)); if (packet_creator_.HasPendingFrames()) { // Incoming packets may change a queued ACK frame. @@ -1198,6 +1135,10 @@ bool QuicConnection::HasPendingAcks() const { return ack_alarm_->IsSet(); } +void QuicConnection::OnUserAgentIdKnown(const std::string& /*user_agent_id*/) { + sent_packet_manager_.OnUserAgentIdKnown(); +} + void QuicConnection::OnDecryptedPacket(size_t /*length*/, EncryptionLevel level) { last_decrypted_packet_level_ = level; @@ -1224,7 +1165,8 @@ void QuicConnection::OnDecryptedPacket(size_t /*length*/, default_path_.validated = true; stats_.address_validated_via_decrypting_packet = true; } - idle_network_detector_.OnPacketReceived(time_of_last_received_packet_); + idle_network_detector_.OnPacketReceived( + last_received_packet_info_.receipt_time); visitor_->OnPacketDecrypted(level); } @@ -1233,7 +1175,7 @@ QuicSocketAddress QuicConnection::GetEffectivePeerAddressFromCurrentPacket() const { // By default, the connection is not proxied, and the effective peer address // is the packet's source address, i.e. the direct peer address. - return last_packet_source_address_; + return last_received_packet_info_.source_address; } bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { @@ -1263,7 +1205,7 @@ bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { // for client connections. // TODO(fayang): only change peer addresses in application data packet // number space. - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); default_path_.peer_address = GetEffectivePeerAddressFromCurrentPacket(); } } else { @@ -1283,6 +1225,44 @@ bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { default_path_.peer_address, GetEffectivePeerAddressFromCurrentPacket()); + if (connection_migration_use_new_cid_) { + auto effective_peer_address = GetEffectivePeerAddressFromCurrentPacket(); + // Since server does not send new connection ID to client before handshake + // completion and source connection ID is omitted in short header packet, + // the server_connection_id on PathState on the server side does not + // affect the packets server writes after handshake completion. On the + // other hand, it is still desirable to have the "correct" server + // connection ID set on path. + // 1) If client uses 1 unique server connection ID per path and the packet + // is received from an existing path, then + // last_packet_destination_connection_id_ will always be the same as the + // server connection ID on path. Server side will maintain the 1-to-1 + // mapping from server connection ID to path. + // 2) If client uses multiple server connection IDs on the same path, + // compared to the server_connection_id on path, + // last_packet_destination_connection_id_ has the advantage that it is + // still present in the session map since the packet can be routed here + // regardless of packet reordering. + if (IsDefaultPath(last_received_packet_info_.destination_address, + effective_peer_address)) { + default_path_.server_connection_id = + last_packet_destination_connection_id_; + } else if (IsAlternativePath( + last_received_packet_info_.destination_address, + effective_peer_address)) { + alternative_path_.server_connection_id = + last_packet_destination_connection_id_; + } + } + + if (last_packet_destination_connection_id_ != + default_path_.server_connection_id && + (!original_destination_connection_id_.has_value() || + last_packet_destination_connection_id_ != + *original_destination_connection_id_)) { + QUIC_CODE_COUNT(quic_connection_id_change); + } + QUIC_DLOG_IF(INFO, current_effective_peer_migration_type_ != NO_CHANGE) << ENDPOINT << "Effective peer's ip:port changed from " << default_path_.peer_address.ToString() << " to " @@ -1300,20 +1280,22 @@ bool QuicConnection::OnPacketHeader(const QuicPacketHeader& header) { // Record packet receipt to populate ack info before processing stream // frames, since the processing may result in sending a bundled ack. + QuicTime receipt_time = idle_network_detector_.time_of_last_received_packet(); + if (reset_per_packet_state_for_undecryptable_packets_ && + SupportsMultiplePacketNumberSpaces()) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_reset_per_packet_state_for_undecryptable_packets, 2, 2); + receipt_time = last_received_packet_info_.receipt_time; + } uber_received_packet_manager_.RecordPacketReceived( - last_decrypted_packet_level_, last_header_, - idle_network_detector_.time_of_last_received_packet()); - if (GetQuicReloadableFlag(quic_enable_token_based_address_validation)) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_enable_token_based_address_validation, 2, - 2); - if (EnforceAntiAmplificationLimit() && !IsHandshakeConfirmed() && - !header.retry_token.empty() && - visitor_->ValidateToken(header.retry_token)) { - QUIC_DLOG(INFO) << ENDPOINT << "Address validated via token."; - QUIC_CODE_COUNT(quic_address_validated_via_token); - default_path_.validated = true; - stats_.address_validated_via_token = true; - } + last_decrypted_packet_level_, last_header_, receipt_time); + if (EnforceAntiAmplificationLimit() && !IsHandshakeConfirmed() && + !header.retry_token.empty() && + visitor_->ValidateToken(header.retry_token)) { + QUIC_DLOG(INFO) << ENDPOINT << "Address validated via token."; + QUIC_CODE_COUNT(quic_address_validated_via_token); + default_path_.validated = true; + stats_.address_validated_via_token = true; } QUICHE_DCHECK(connected_); return true; @@ -1352,6 +1334,9 @@ bool QuicConnection::OnStreamFrame(const QuicStreamFrame& frame) { ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); return false; } + // TODO(fayang): Consider moving UpdatePacketContent and + // MaybeUpdateAckTimeout to a stand-alone function instead of calling them for + // all frames. MaybeUpdateAckTimeout(); visitor_->OnStreamFrame(frame); stats_.stream_bytes_received += frame.data_length; @@ -1424,6 +1409,10 @@ bool QuicConnection::OnAckFrameStart(QuicPacketNumber largest_acked, sent_packet_manager_.OnAckFrameStart( largest_acked, ack_delay_time, idle_network_detector_.time_of_last_received_packet()); + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnAckFrameStart(ack_delay_time); + } return true; } @@ -1440,6 +1429,10 @@ bool QuicConnection::OnAckRange(QuicPacketNumber start, QuicPacketNumber end) { } sent_packet_manager_.OnAckRange(start, end); + + if (debug_visitor_ != nullptr) { + debug_visitor_->OnAckRange(start, end); + } return true; } @@ -1468,6 +1461,10 @@ bool QuicConnection::OnAckFrameEnd(QuicPacketNumber start) { << most_recent_frame_type_; QUIC_DVLOG(1) << ENDPOINT << "OnAckFrameEnd, start: " << start; + if (debug_visitor_ != nullptr) { + debug_visitor_->OnAckFrameEnd(start); + } + if (GetLargestReceivedPacketWithAck().IsInitialized() && last_header_.packet_number <= GetLargestReceivedPacketWithAck()) { QUIC_DLOG(INFO) << ENDPOINT << "Received an old ack frame: ignoring"; @@ -1660,7 +1657,7 @@ bool QuicConnection::OnStopSendingFrame(const QuicStopSendingFrame& frame) { QUIC_DLOG(INFO) << ENDPOINT << "STOP_SENDING frame received for stream: " << frame.stream_id << " with error: " << frame.ietf_error_code; - + MaybeUpdateAckTimeout(); visitor_->OnStopSendingFrame(frame); return connected_; } @@ -1698,17 +1695,8 @@ bool QuicConnection::OnPathChallengeFrame(const QuicPathChallengeFrame& frame) { } QUIC_CODE_COUNT_N(quic_server_reverse_validate_new_path3, 1, 6); { - // UpdatePacketStateAndReplyPathChallenge() may start reverse path - // validation, if so bundle the PATH_CHALLENGE together with the - // PATH_RESPONSE. This context needs to be out of scope before returning. // TODO(danzh) inline OnPathChallengeFrameInternal() once - // support_reverse_path_validation_ is deprecated. - auto context = - group_path_response_and_challenge_sending_closer_ - ? nullptr - : std::make_unique( - &packet_creator_, last_packet_source_address_, - /*update_connection_id=*/false); + // validate_client_addresses_ is deprecated. if (!OnPathChallengeFrameInternal(frame)) { return false; } @@ -1727,17 +1715,16 @@ bool QuicConnection::OnPathChallengeFrameInternal( debug_visitor_->OnPathChallengeFrame(frame); } - std::unique_ptr context; - if (group_path_response_and_challenge_sending_closer_) { - context = std::make_unique( - &packet_creator_, last_packet_source_address_, - /*update_connection_id=*/false); - } + const QuicSocketAddress current_effective_peer_address = + GetEffectivePeerAddressFromCurrentPacket(); + QuicConnectionId client_cid, server_cid; + FindOnPathConnectionIds(last_received_packet_info_.destination_address, + current_effective_peer_address, &client_cid, + &server_cid); + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, last_received_packet_info_.source_address, client_cid, + server_cid, connection_migration_use_new_cid_); if (should_proactively_validate_peer_address_on_path_challenge_) { - QUIC_RELOADABLE_FLAG_COUNT( - quic_group_path_response_and_challenge_sending_closer); - QuicSocketAddress current_effective_peer_address = - GetEffectivePeerAddressFromCurrentPacket(); // Conditions to proactively validate peer address: // The perspective is server // The PATH_CHALLENGE is received on an unvalidated alternative path. @@ -1745,8 +1732,10 @@ bool QuicConnection::OnPathChallengeFrameInternal( // higher prority. QUIC_DVLOG(1) << "Proactively validate the effective peer address " << current_effective_peer_address; + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 2, 6); ValidatePath(std::make_unique( - default_path_.self_address, current_effective_peer_address, + default_path_.self_address, + last_received_packet_info_.source_address, current_effective_peer_address, this), std::make_unique( this, peer_address())); @@ -1765,8 +1754,9 @@ bool QuicConnection::OnPathChallengeFrameInternal( // Queue or send PATH_RESPONSE. Send PATH_RESPONSE to the source address of // the current incoming packet, even if it's not the default path or the // alternative path. - const bool success = - SendPathResponse(frame.data_buffer, last_packet_source_address_); + const bool success = SendPathResponse( + frame.data_buffer, last_received_packet_info_.source_address, + current_effective_peer_address); if (GetQuicReloadableFlag(quic_drop_unsent_path_response)) { QUIC_RELOADABLE_FLAG_COUNT(quic_drop_unsent_path_response); } @@ -1775,7 +1765,7 @@ bool QuicConnection::OnPathChallengeFrameInternal( if (!GetQuicReloadableFlag(quic_drop_unsent_path_response)) { // Queue the payloads to re-try later. pending_path_challenge_payloads_.push_back( - {frame.data_buffer, last_packet_source_address_}); + {frame.data_buffer, last_received_packet_info_.source_address}); } } // TODO(b/150095588): change the stats to @@ -1800,8 +1790,8 @@ bool QuicConnection::OnPathResponseFrame(const QuicPathResponseFrame& frame) { MaybeUpdateAckTimeout(); if (use_path_validator_) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_pass_path_response_to_validator, 1, 4); - path_validator_.OnPathResponse(frame.data_buffer, - last_packet_destination_address_); + path_validator_.OnPathResponse( + frame.data_buffer, last_received_packet_info_.destination_address); } else { if (!transmitted_connectivity_probe_payload_ || *transmitted_connectivity_probe_payload_ != frame.data_buffer) { @@ -1843,7 +1833,10 @@ bool QuicConnection::OnConnectionCloseFrame( << connection_id() << ", with error: " << QuicErrorCodeToString(frame.quic_error_code) << " (" << frame.error_details << ")" - << ", transport error code: " << frame.wire_error_code + << ", transport error code: " + << QuicIetfTransportErrorCodeString( + static_cast( + frame.wire_error_code)) << ", error frame type: " << frame.transport_close_frame_type; break; @@ -1877,6 +1870,7 @@ bool QuicConnection::OnMaxStreamsFrame(const QuicMaxStreamsFrame& frame) { if (debug_visitor_ != nullptr) { debug_visitor_->OnMaxStreamsFrame(frame); } + MaybeUpdateAckTimeout(); return visitor_->OnMaxStreamsFrame(frame) && connected_; } @@ -1893,6 +1887,7 @@ bool QuicConnection::OnStreamsBlockedFrame( if (debug_visitor_ != nullptr) { debug_visitor_->OnStreamsBlockedFrame(frame); } + MaybeUpdateAckTimeout(); return visitor_->OnStreamsBlockedFrame(frame) && connected_; } @@ -1947,12 +1942,14 @@ void QuicConnection::OnClientConnectionIdAvailable() { return; } if (default_path_.client_connection_id.IsEmpty()) { + // Count client connection ID patched onto the default path. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 3, + 6); const QuicConnectionIdData* unused_cid_data = peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); QUIC_DVLOG(1) << ENDPOINT << "Patch connection ID " << unused_cid_data->connection_id << " to default path"; default_path_.client_connection_id = unused_cid_data->connection_id; - default_path_.stateless_reset_token_received = true; default_path_.stateless_reset_token = unused_cid_data->stateless_reset_token; QUICHE_DCHECK(!packet_creator_.HasPendingFrames()); @@ -1962,20 +1959,43 @@ void QuicConnection::OnClientConnectionIdAvailable() { } if (alternative_path_.peer_address.IsInitialized() && alternative_path_.client_connection_id.IsEmpty()) { + // Count client connection ID patched onto the alternative path. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 4, + 6); const QuicConnectionIdData* unused_cid_data = peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); QUIC_DVLOG(1) << ENDPOINT << "Patch connection ID " << unused_cid_data->connection_id << " to alternative path"; alternative_path_.client_connection_id = unused_cid_data->connection_id; - alternative_path_.stateless_reset_token_received = true; alternative_path_.stateless_reset_token = unused_cid_data->stateless_reset_token; } } +bool QuicConnection::ShouldSetRetransmissionAlarmOnPacketSent( + bool in_flight, EncryptionLevel level) const { + if (!retransmission_alarm_->IsSet()) { + return true; + } + if (!in_flight) { + return false; + } + + if (!SupportsMultiplePacketNumberSpaces()) { + return true; + } + // Before handshake gets confirmed, do not re-arm PTO timer on application + // data. Think about this scenario: on the client side, the CHLO gets + // acknowledged and the SHLO is not received yet. The PTO alarm is set when + // the CHLO acknowledge is received (and there is no in flight INITIAL + // packet). Re-arming PTO alarm on 0-RTT packet would keep postponing the PTO + // alarm. + return IsHandshakeConfirmed() || level == ENCRYPTION_INITIAL || + level == ENCRYPTION_HANDSHAKE; +} + bool QuicConnection::OnNewConnectionIdFrameInner( const QuicNewConnectionIdFrame& frame) { - QUICHE_DCHECK(support_multiple_connection_ids_); if (peer_issued_cid_manager_ == nullptr) { CloseConnection( IETF_QUIC_PROTOCOL_VIOLATION, @@ -1991,16 +2011,16 @@ bool QuicConnection::OnNewConnectionIdFrameInner( ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); return false; } - if (use_connection_id_on_default_path_ && - perspective_ == Perspective::IS_SERVER) { + if (perspective_ == Perspective::IS_SERVER) { OnClientConnectionIdAvailable(); } - QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_support_multiple_cids_v4, 1, 2); + MaybeUpdateAckTimeout(); return true; } bool QuicConnection::OnNewConnectionIdFrame( const QuicNewConnectionIdFrame& frame) { + QUICHE_DCHECK(version().HasIetfQuicFrames()); QUIC_BUG_IF(quic_bug_10511_13, !connected_) << "Processing NEW_CONNECTION_ID frame when " "connection is closed. Last frame: " @@ -2012,14 +2032,12 @@ bool QuicConnection::OnNewConnectionIdFrame( if (debug_visitor_ != nullptr) { debug_visitor_->OnNewConnectionIdFrame(frame); } - if (!support_multiple_connection_ids_) { - return true; - } return OnNewConnectionIdFrameInner(frame); } bool QuicConnection::OnRetireConnectionIdFrame( const QuicRetireConnectionIdFrame& frame) { + QUICHE_DCHECK(version().HasIetfQuicFrames()); QUIC_BUG_IF(quic_bug_10511_14, !connected_) << "Processing RETIRE_CONNECTION_ID frame when " "connection is closed. Last frame: " @@ -2031,13 +2049,10 @@ bool QuicConnection::OnRetireConnectionIdFrame( if (debug_visitor_ != nullptr) { debug_visitor_->OnRetireConnectionIdFrame(frame); } - if (use_connection_id_on_default_path_) { + if (!connection_migration_use_new_cid_) { // Do not respond to RetireConnectionId frame. return true; } - if (!support_multiple_connection_ids_) { - return true; - } if (self_issued_cid_manager_ == nullptr) { CloseConnection( IETF_QUIC_PROTOCOL_VIOLATION, @@ -2053,7 +2068,9 @@ bool QuicConnection::OnRetireConnectionIdFrame( ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); return false; } - QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_support_multiple_cids_v4, 2, 2); + // Count successfully received RETIRE_CONNECTION_ID frames. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 5, 6); + MaybeUpdateAckTimeout(); return true; } @@ -2068,17 +2085,14 @@ bool QuicConnection::OnNewTokenFrame(const QuicNewTokenFrame& frame) { if (debug_visitor_ != nullptr) { debug_visitor_->OnNewTokenFrame(frame); } - if (GetQuicReloadableFlag(quic_enable_token_based_address_validation)) { - if (perspective_ == Perspective::IS_SERVER) { - CloseConnection(QUIC_INVALID_NEW_TOKEN, - "Server received new token frame.", - ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); - return false; - } - // NEW_TOKEN frame should insitgate ACKs. - MaybeUpdateAckTimeout(); - visitor_->OnNewTokenReceived(frame.token); + if (perspective_ == Perspective::IS_SERVER) { + CloseConnection(QUIC_INVALID_NEW_TOKEN, "Server received new token frame.", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return false; } + // NEW_TOKEN frame should insitgate ACKs. + MaybeUpdateAckTimeout(); + visitor_->OnNewTokenReceived(frame.token); return true; } @@ -2187,6 +2201,9 @@ bool QuicConnection::OnBlockedFrame(const QuicBlockedFrame& frame) { } void QuicConnection::OnPacketComplete() { + if (debug_visitor_ != nullptr) { + debug_visitor_->OnPacketComplete(); + } // Don't do anything if this packet closed the connection. if (!connected_) { ClearLastFrames(); @@ -2236,8 +2253,8 @@ void QuicConnection::MaybeRespondToConnectivityProbingOrMigration() { if (perspective_ == Perspective::IS_CLIENT) { // This node is a client, notify that a speculative connectivity probing // packet has been received anyway. - visitor_->OnPacketReceived(last_packet_destination_address_, - last_packet_source_address_, + visitor_->OnPacketReceived(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address, /*is_connectivity_probe=*/false); return; } @@ -2250,54 +2267,43 @@ void QuicConnection::MaybeRespondToConnectivityProbingOrMigration() { // If the packet contains PATH CHALLENGE, send appropriate RESPONSE. // There was at least one PATH CHALLENGE in the received packet, // Generate the required PATH RESPONSE. - SendGenericPathProbePacket(nullptr, last_packet_source_address_, + SendGenericPathProbePacket(nullptr, + last_received_packet_info_.source_address, /* is_response=*/true); return; } } else { if (IsCurrentPacketConnectivityProbing()) { - visitor_->OnPacketReceived(last_packet_destination_address_, - last_packet_source_address_, + visitor_->OnPacketReceived(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address, /*is_connectivity_probe=*/true); return; } if (perspective_ == Perspective::IS_CLIENT) { // This node is a client, notify that a speculative connectivity probing // packet has been received anyway. - QUIC_DVLOG(1) << ENDPOINT - << "Received a speculative connectivity probing packet for " - << GetServerConnectionIdAsRecipient(last_header_, - perspective_) - << " from ip:port: " - << last_packet_source_address_.ToString() << " to ip:port: " - << last_packet_destination_address_.ToString(); - visitor_->OnPacketReceived(last_packet_destination_address_, - last_packet_source_address_, + QUIC_DVLOG(1) + << ENDPOINT + << "Received a speculative connectivity probing packet for " + << GetServerConnectionIdAsRecipient(last_header_, perspective_) + << " from ip:port: " + << last_received_packet_info_.source_address.ToString() + << " to ip:port: " + << last_received_packet_info_.destination_address.ToString(); + visitor_->OnPacketReceived(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address, /*is_connectivity_probe=*/false); return; } } - // Server starts to migrate connection upon receiving of non-probing packet - // from a new peer address. - if (!start_peer_migration_earlier_ && - last_header_.packet_number == GetLargestReceivedPacket()) { - direct_peer_address_ = last_packet_source_address_; - if (current_effective_peer_migration_type_ != NO_CHANGE) { - // TODO(fayang): When multiple packet number spaces is supported, only - // start peer migration for the application data. - StartEffectivePeerMigration(current_effective_peer_migration_type_); - } - } } bool QuicConnection::IsValidStatelessResetToken( const StatelessResetToken& token) const { - if (use_connection_id_on_default_path_) { - return default_path_.stateless_reset_token_received && - token == default_path_.stateless_reset_token; - } - return stateless_reset_token_received_ && - token == received_stateless_reset_token_; + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + return default_path_.stateless_reset_token.has_value() && + QuicUtils::AreStatelessResetTokensEqual( + token, *default_path_.stateless_reset_token); } void QuicConnection::OnAuthenticatedIetfStatelessResetPacket( @@ -2309,10 +2315,10 @@ void QuicConnection::OnAuthenticatedIetfStatelessResetPacket( if (use_path_validator_) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_pass_path_response_to_validator, 4, 4); - if (!IsDefaultPath(last_packet_destination_address_, - last_packet_source_address_)) { + if (!IsDefaultPath(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address)) { // This packet is received on a probing path. Do not close connection. - if (IsAlternativePath(last_packet_destination_address_, + if (IsAlternativePath(last_received_packet_info_.destination_address, GetEffectivePeerAddressFromCurrentPacket())) { QUIC_BUG_IF(quic_bug_12714_18, alternative_path_.validated) << "STATELESS_RESET received on alternate path after it's " @@ -2324,8 +2330,9 @@ void QuicConnection::OnAuthenticatedIetfStatelessResetPacket( } return; } - } else if (!visitor_->ValidateStatelessReset(last_packet_destination_address_, - last_packet_source_address_)) { + } else if (!visitor_->ValidateStatelessReset( + last_received_packet_info_.destination_address, + last_received_packet_info_.source_address)) { // This packet is received on a probing path. Do not close connection. return; } @@ -2380,23 +2387,13 @@ void QuicConnection::ClearLastFrames() { } void QuicConnection::CloseIfTooManyOutstandingSentPackets() { - bool should_close; - if (GetQuicReloadableFlag( - quic_close_connection_with_too_many_outstanding_packets)) { - QUIC_RELOADABLE_FLAG_COUNT( - quic_close_connection_with_too_many_outstanding_packets); - should_close = - sent_packet_manager_.GetLargestSentPacket().IsInitialized() && - sent_packet_manager_.GetLargestSentPacket() > - sent_packet_manager_.GetLeastUnacked() + max_tracked_packets_; - } else { - should_close = - sent_packet_manager_.GetLargestObserved().IsInitialized() && - sent_packet_manager_.GetLargestObserved() > - sent_packet_manager_.GetLeastUnacked() + max_tracked_packets_; - } // This occurs if we don't discard old packets we've seen fast enough. It's // possible largest observed is less than leaset unacked. + const bool should_close = + sent_packet_manager_.GetLargestSentPacket().IsInitialized() && + sent_packet_manager_.GetLargestSentPacket() > + sent_packet_manager_.GetLeastUnacked() + max_tracked_packets_; + if (should_close) { CloseConnection( QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS, @@ -2514,28 +2511,20 @@ QuicConsumedData QuicConnection::SendStreamData(QuicStreamId id, QuicUtils::IsCryptoStreamId(transport_version(), id)) { MaybeActivateLegacyVersionEncapsulation(); } - if (GetQuicReloadableFlag(quic_preempt_stream_data_with_handshake_packet) && - perspective_ == Perspective::IS_SERVER && + if (perspective_ == Perspective::IS_SERVER && version().CanSendCoalescedPackets() && !IsHandshakeConfirmed()) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_preempt_stream_data_with_handshake_packet, - 1, 2); - if (GetQuicReloadableFlag(quic_donot_pto_half_rtt_data)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_donot_pto_half_rtt_data); - if (in_on_retransmission_time_out_ && - coalesced_packet_.NumberOfPackets() == 0u) { - // PTO fires while handshake is not confirmed. Do not preempt handshake - // data with stream data. - QUIC_CODE_COUNT(quic_try_to_send_half_rtt_data_when_pto_fires); - return QuicConsumedData(0, false); - } + if (in_on_retransmission_time_out_ && + coalesced_packet_.NumberOfPackets() == 0u) { + // PTO fires while handshake is not confirmed. Do not preempt handshake + // data with stream data. + QUIC_CODE_COUNT(quic_try_to_send_half_rtt_data_when_pto_fires); + return QuicConsumedData(0, false); } if (coalesced_packet_.ContainsPacketOfEncryptionLevel(ENCRYPTION_INITIAL) && coalesced_packet_.NumberOfPackets() == 1u) { // Handshake is not confirmed yet, if there is only an initial packet in // the coalescer, try to bundle an ENCRYPTION_HANDSHAKE packet before // sending stream data. - QUIC_RELOADABLE_FLAG_COUNT_N( - quic_preempt_stream_data_with_handshake_packet, 2, 2); sent_packet_manager_.RetransmitDataOfSpaceIfAny(HANDSHAKE_DATA); } } @@ -2620,8 +2609,8 @@ const QuicConnectionStats& QuicConnection::GetStats() { stats_.estimated_bandwidth = sent_packet_manager_.BandwidthEstimate(); sent_packet_manager_.GetSendAlgorithm()->PopulateConnectionStats(&stats_); - stats_.max_packet_size = packet_creator_.max_packet_length(); - stats_.max_received_packet_size = largest_received_packet_size_; + stats_.egress_mtu = long_term_mtu_; + stats_.ingress_mtu = largest_received_packet_size_; return stats_; } @@ -2690,9 +2679,13 @@ void QuicConnection::OnUndecryptablePacket(const QuicEncryptedPacket& packet, } bool QuicConnection::ShouldEnqueueUnDecryptablePacket( - EncryptionLevel decryption_level, - bool has_decryption_key) const { - if (encryption_level_ == ENCRYPTION_FORWARD_SECURE) { + EncryptionLevel decryption_level, bool has_decryption_key) const { + if (has_decryption_key) { + // We already have the key for this decryption level, therefore no + // future keys will allow it be decrypted. + return false; + } + if (IsHandshakeComplete()) { // We do not expect to install any further keys. return false; } @@ -2700,15 +2693,17 @@ bool QuicConnection::ShouldEnqueueUnDecryptablePacket( // We do not queue more than max_undecryptable_packets_ packets. return false; } - if (has_decryption_key) { - // We already have the key for this decryption level, therefore no - // future keys will allow it be decrypted. + if (version().KnowsWhichDecrypterToUse() && + decryption_level == ENCRYPTION_INITIAL) { + // When the corresponding decryption key is not available, all + // non-Initial packets should be buffered until the handshake is complete. return false; } - if (version().KnowsWhichDecrypterToUse() && - decryption_level <= encryption_level_) { - // On versions that know which decrypter to use, we install keys in order - // so we will not get newer keys for lower encryption levels. + if (perspective_ == Perspective::IS_CLIENT && version().UsesTls() && + decryption_level == ENCRYPTION_ZERO_RTT) { + // Only clients send Zero RTT packets in IETF QUIC. + QUIC_PEER_BUG(quic_peer_bug_client_received_zero_rtt) + << "Client received a Zero RTT packet, not buffering."; return false; } return true; @@ -2763,18 +2758,17 @@ void QuicConnection::ProcessUdpPacket(const QuicSocketAddress& self_address, if (debug_visitor_ != nullptr) { debug_visitor_->OnPacketReceived(self_address, peer_address, packet); } - current_incoming_packet_received_bytes_counted_ = false; + last_received_packet_info_ = + ReceivedPacketInfo(self_address, peer_address, packet.receipt_time()); last_size_ = packet.length(); current_packet_data_ = packet.data(); - last_packet_destination_address_ = self_address; - last_packet_source_address_ = peer_address; if (!default_path_.self_address.IsInitialized()) { - default_path_.self_address = last_packet_destination_address_; + default_path_.self_address = last_received_packet_info_.destination_address; } if (!direct_peer_address_.IsInitialized()) { - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); } if (!default_path_.peer_address.IsInitialized()) { @@ -2795,11 +2789,11 @@ void QuicConnection::ProcessUdpPacket(const QuicSocketAddress& self_address, if (EnforceAntiAmplificationLimit()) { default_path_.bytes_received_before_address_validation += last_size_; } - } else if (IsDefaultPath(last_packet_destination_address_, - last_packet_source_address_) && + } else if (IsDefaultPath(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address) && EnforceAntiAmplificationLimit()) { QUIC_CODE_COUNT_N(quic_count_bytes_on_alternative_path_seperately, 1, 5); - current_incoming_packet_received_bytes_counted_ = true; + last_received_packet_info_.received_bytes_counted = true; default_path_.bytes_received_before_address_validation += last_size_; } @@ -2811,10 +2805,9 @@ void QuicConnection::ProcessUdpPacket(const QuicSocketAddress& self_address, << " too far from current time:" << clock_->ApproximateNow().ToDebuggingValue(); } - time_of_last_received_packet_ = packet.receipt_time(); QUIC_DVLOG(1) << ENDPOINT << "time of last received packet: " << packet.receipt_time().ToDebuggingValue() << " from peer " - << last_packet_source_address_; + << last_received_packet_info_.source_address; ScopedPacketFlusher flusher(this); if (!framer_.ProcessPacket(packet)) { @@ -2848,9 +2841,7 @@ void QuicConnection::ProcessUdpPacket(const QuicSocketAddress& self_address, } } - const bool processed = MaybeProcessCoalescedPackets(); - if (!donot_write_mid_packet_processing_ || !processed) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_donot_write_mid_packet_processing, 3, 3); + if (!MaybeProcessCoalescedPackets()) { MaybeProcessUndecryptablePackets(); MaybeSendInResponseToPacket(); } @@ -2900,8 +2891,12 @@ void QuicConnection::OnCanWrite() { QUIC_RELOADABLE_FLAG_COUNT_N(quic_send_path_response2, 4, 5); const PendingPathChallenge& pending_path_challenge = pending_path_challenge_payloads_.front(); + // Note connection_migration_use_cid_ will depends on + // quic_drop_unsent_path_response flag eventually, and hence the empty + // effective_peer_address here will not be used. if (!SendPathResponse(pending_path_challenge.received_path_challenge, - pending_path_challenge.peer_address)) { + pending_path_challenge.peer_address, + /*effective_peer_address=*/QuicSocketAddress())) { break; } pending_path_challenge_payloads_.pop_front(); @@ -2927,7 +2922,7 @@ void QuicConnection::OnCanWrite() { } void QuicConnection::WriteIfNotBlocked() { - if (donot_write_mid_packet_processing_ && framer().is_processing_packet()) { + if (framer().is_processing_packet()) { QUIC_BUG(connection_write_mid_packet_processing) << ENDPOINT << "Tried to write in mid of packet processing"; return; @@ -2937,78 +2932,108 @@ void QuicConnection::WriteIfNotBlocked() { } } -void QuicConnection::SetServerConnectionId( - const QuicConnectionId& server_connection_id) { - if (use_connection_id_on_default_path_) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_use_connection_id_on_default_path, 2, 2); - default_path_.server_connection_id = server_connection_id; - } else { - server_connection_id_ = server_connection_id; +void QuicConnection::MaybeClearQueuedPacketsOnPathChange() { + if (connection_migration_use_new_cid_ && + peer_issued_cid_manager_ != nullptr && HasQueuedPackets()) { + // Discard packets serialized with the connection ID on the old code path. + // It is possible to clear queued packets only if connection ID changes. + // However, the case where connection ID is unchanged and queued packets are + // non-empty is quite rare. + ClearQueuedPackets(); } } void QuicConnection::ReplaceInitialServerConnectionId( const QuicConnectionId& new_server_connection_id) { QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); - if (support_multiple_connection_ids_) { + if (version().HasIetfQuicFrames()) { if (new_server_connection_id.IsEmpty()) { peer_issued_cid_manager_ = nullptr; } else { if (peer_issued_cid_manager_ != nullptr) { QUIC_BUG_IF(quic_bug_12714_22, !peer_issued_cid_manager_->IsConnectionIdActive( - ServerConnectionId())) + default_path_.server_connection_id)) << "Connection ID replaced header is no longer active. old id: " - << ServerConnectionId() << " new_id: " << new_server_connection_id; - peer_issued_cid_manager_->ReplaceConnectionId(ServerConnectionId(), - new_server_connection_id); + << default_path_.server_connection_id + << " new_id: " << new_server_connection_id; + peer_issued_cid_manager_->ReplaceConnectionId( + default_path_.server_connection_id, new_server_connection_id); } else { peer_issued_cid_manager_ = std::make_unique( kMinNumOfActiveConnectionIds, new_server_connection_id, clock_, - alarm_factory_, this); + alarm_factory_, this, context()); } } } - SetServerConnectionId(new_server_connection_id); - packet_creator_.SetServerConnectionId(ServerConnectionId()); + default_path_.server_connection_id = new_server_connection_id; + packet_creator_.SetServerConnectionId(default_path_.server_connection_id); } -void QuicConnection::FindMatchingClientConnectionIdOrToken( - const PathState& default_path, - const PathState& alternative_path, +void QuicConnection::FindMatchingOrNewClientConnectionIdOrToken( + const PathState& default_path, const PathState& alternative_path, const QuicConnectionId& server_connection_id, QuicConnectionId* client_connection_id, - bool* stateless_reset_token_received, - StatelessResetToken* stateless_reset_token) const { - if (!use_connection_id_on_default_path_) { - return; - } + absl::optional* stateless_reset_token) { QUICHE_DCHECK(perspective_ == Perspective::IS_SERVER); if (peer_issued_cid_manager_ == nullptr || server_connection_id == default_path.server_connection_id) { *client_connection_id = default_path.client_connection_id; - *stateless_reset_token_received = - default_path.stateless_reset_token_received; *stateless_reset_token = default_path.stateless_reset_token; return; } if (server_connection_id == alternative_path_.server_connection_id) { *client_connection_id = alternative_path.client_connection_id; - *stateless_reset_token_received = - alternative_path.stateless_reset_token_received; *stateless_reset_token = alternative_path.stateless_reset_token; return; } - QUIC_BUG(quic_bug_46004) << "Cannot find matching connection ID."; + if (!connection_migration_use_new_cid_) { + QUIC_BUG(quic_bug_46004) << "Cannot find matching connection ID."; + return; + } + auto* connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + if (connection_id_data == nullptr) { + return; + } + *client_connection_id = connection_id_data->connection_id; + *stateless_reset_token = connection_id_data->stateless_reset_token; +} + +bool QuicConnection::FindOnPathConnectionIds( + const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicConnectionId* client_connection_id, + QuicConnectionId* server_connection_id) const { + if (IsDefaultPath(self_address, peer_address)) { + *client_connection_id = default_path_.client_connection_id, + *server_connection_id = default_path_.server_connection_id; + return true; + } + if (IsAlternativePath(self_address, peer_address)) { + *client_connection_id = alternative_path_.client_connection_id, + *server_connection_id = alternative_path_.server_connection_id; + return true; + } + return false; +} + +void QuicConnection::SetDefaultPathState(PathState new_path_state) { + default_path_ = std::move(new_path_state); + if (connection_migration_use_new_cid_) { + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + packet_creator_.SetServerConnectionId(default_path_.server_connection_id); + } } bool QuicConnection::ProcessValidatedPacket(const QuicPacketHeader& header) { if (perspective_ == Perspective::IS_CLIENT && version().HasIetfQuicFrames() && direct_peer_address_.IsInitialized() && - last_packet_source_address_.IsInitialized() && - direct_peer_address_ != last_packet_source_address_ && - !visitor_->IsKnownServerAddress(last_packet_source_address_)) { + last_received_packet_info_.source_address.IsInitialized() && + direct_peer_address_ != last_received_packet_info_.source_address && + !visitor_->IsKnownServerAddress( + last_received_packet_info_.source_address)) { // TODO(haoyuewang) Revisit this when preferred_address transport parameter // is used on the client side. // Discard packets received from unseen server addresses. @@ -3017,13 +3042,15 @@ bool QuicConnection::ProcessValidatedPacket(const QuicPacketHeader& header) { if (perspective_ == Perspective::IS_SERVER && default_path_.self_address.IsInitialized() && - last_packet_destination_address_.IsInitialized() && - default_path_.self_address != last_packet_destination_address_) { + last_received_packet_info_.destination_address.IsInitialized() && + default_path_.self_address != + last_received_packet_info_.destination_address) { // Allow change between pure IPv4 and equivalent mapped IPv4 address. if (default_path_.self_address.port() != - last_packet_destination_address_.port() || + last_received_packet_info_.destination_address.port() || default_path_.self_address.host().Normalized() != - last_packet_destination_address_.host().Normalized()) { + last_received_packet_info_.destination_address.host() + .Normalized()) { if (!visitor_->AllowSelfAddressChange()) { CloseConnection( QUIC_ERROR_MIGRATING_ADDRESS, @@ -3032,24 +3059,24 @@ bool QuicConnection::ProcessValidatedPacket(const QuicPacketHeader& header) { return false; } } - default_path_.self_address = last_packet_destination_address_; + default_path_.self_address = last_received_packet_info_.destination_address; } - if (PacketCanReplaceConnectionId(header, perspective_) && - ServerConnectionId() != header.source_connection_id) { + if (PacketCanReplaceServerConnectionId(header, perspective_) && + default_path_.server_connection_id != header.source_connection_id) { QUICHE_DCHECK_EQ(header.long_packet_type, INITIAL); if (server_connection_id_replaced_by_initial_) { QUIC_DLOG(ERROR) << ENDPOINT << "Refusing to replace connection ID " - << ServerConnectionId() << " with " + << default_path_.server_connection_id << " with " << header.source_connection_id; return false; } server_connection_id_replaced_by_initial_ = true; QUIC_DLOG(INFO) << ENDPOINT << "Replacing connection ID " - << ServerConnectionId() << " with " + << default_path_.server_connection_id << " with " << header.source_connection_id; if (!original_destination_connection_id_.has_value()) { - original_destination_connection_id_ = ServerConnectionId(); + original_destination_connection_id_ = default_path_.server_connection_id; } ReplaceInitialServerConnectionId(header.source_connection_id); } @@ -3080,7 +3107,11 @@ bool QuicConnection::ProcessValidatedPacket(const QuicPacketHeader& header) { if (perspective_ == Perspective::IS_SERVER && encryption_level_ == ENCRYPTION_INITIAL && last_size_ > packet_creator_.max_packet_length()) { - SetMaxPacketLength(last_size_); + if (GetQuicFlag(FLAGS_quic_use_lower_server_response_mtu_for_test)) { + SetMaxPacketLength(std::min(last_size_, QuicByteCount(1250))); + } else { + SetMaxPacketLength(last_size_); + } } return true; } @@ -3189,8 +3220,9 @@ bool QuicConnection::ShouldGeneratePacket( QuicVersionUsesCryptoFrames(transport_version())) << ENDPOINT << "Handshake in STREAM frames should not check ShouldGeneratePacket"; - if (support_multiple_connection_ids_ && peer_issued_cid_manager_ != nullptr && + if (peer_issued_cid_manager_ != nullptr && packet_creator_.GetDestinationConnectionId().IsEmpty()) { + QUICHE_DCHECK(version().HasIetfQuicFrames()); QUIC_CODE_COUNT(quic_generate_packet_blocked_by_no_connection_id); QUIC_BUG_IF(quic_bug_90265_1, perspective_ == Perspective::IS_CLIENT); QUIC_DLOG(INFO) << ENDPOINT @@ -3256,6 +3288,20 @@ bool QuicConnection::CanWrite(HasRetransmittableData retransmittable) { return false; } + if (GetQuicReloadableFlag(quic_suppress_write_mid_packet_processing) && + version().CanSendCoalescedPackets() && + framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL) && + framer_.is_processing_packet()) { + QUIC_RELOADABLE_FLAG_COUNT(quic_suppress_write_mid_packet_processing); + // While we still have initial keys, suppress sending in mid of packet + // processing. + // TODO(fayang): always suppress sending while in the mid of packet + // processing. + QUIC_DVLOG(1) << ENDPOINT + << "Suppress sending in the mid of packet processing"; + return false; + } + if (fill_coalesced_packet_) { // Try to coalesce packet, only allow to write when creator is on soft max // packet length. Given the next created packet is going to fill current @@ -3263,7 +3309,12 @@ bool QuicConnection::CanWrite(HasRetransmittableData retransmittable) { return packet_creator_.HasSoftMaxPacketLength(); } - if (LimitedByAmplificationFactor()) { + const bool donot_check_amplification_limit_with_pending_timer_credit = + GetQuicReloadableFlag( + quic_donot_check_amplification_limit_with_pending_timer_credit); + + if (!donot_check_amplification_limit_with_pending_timer_credit && + LimitedByAmplificationFactor()) { // Server is constrained by the amplification restriction. QUIC_CODE_COUNT(quic_throttled_by_amplification_limit); QUIC_DVLOG(1) << ENDPOINT @@ -3277,10 +3328,31 @@ bool QuicConnection::CanWrite(HasRetransmittableData retransmittable) { } if (sent_packet_manager_.pending_timer_transmission_count() > 0) { - // Force sending the retransmissions for HANDSHAKE, TLP, RTO, PROBING cases. + // Allow sending if there are pending tokens, which occurs when: + // 1) firing PTO, + // 2) bundling CRYPTO data with ACKs, + // 3) coalescing CRYPTO data of higher space. return true; } + if (donot_check_amplification_limit_with_pending_timer_credit) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_donot_check_amplification_limit_with_pending_timer_credit); + if (LimitedByAmplificationFactor()) { + // Server is constrained by the amplification restriction. + QUIC_CODE_COUNT(quic_throttled_by_amplification_limit); + QUIC_DVLOG(1) + << ENDPOINT + << "Constrained by amplification restriction to peer address " + << default_path_.peer_address << " bytes received " + << default_path_.bytes_received_before_address_validation + << ", bytes sent" + << default_path_.bytes_sent_before_address_validation; + ++stats_.num_amplification_throttling; + return false; + } + } + if (HandleWriteBlocked()) { return false; } @@ -3478,7 +3550,7 @@ bool QuicConnection::WritePacket(SerializedPacket* packet) { legacy_version_encapsulation_sni_, absl::string_view(packet->encrypted_buffer, packet->encrypted_length), - ServerConnectionId(), framer_.creation_time(), + default_path_.server_connection_id, framer_.creation_time(), GetLimitedMaxPacketSize(long_term_mtu_), const_cast(packet->encrypted_buffer)); if (encapsulated_length != 0) { @@ -3631,7 +3703,8 @@ bool QuicConnection::WritePacket(SerializedPacket* packet) { packet, packet_send_time, packet->transmission_type, IsRetransmittable(*packet), /*measure_rtt=*/send_on_current_path); QUIC_BUG_IF(quic_bug_12714_25, - default_enable_5rto_blackhole_detection_ && + perspective_ == Perspective::IS_SERVER && + default_enable_5rto_blackhole_detection_ && blackhole_detector_.IsDetectionInProgress() && !sent_packet_manager_.HasInFlightPackets()) << ENDPOINT @@ -3668,8 +3741,15 @@ bool QuicConnection::WritePacket(SerializedPacket* packet) { return true; } } - - if (in_flight || !retransmission_alarm_->IsSet()) { + if (GetQuicReloadableFlag( + quic_donot_rearm_pto_on_application_data_during_handshake)) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_donot_rearm_pto_on_application_data_during_handshake); + if (ShouldSetRetransmissionAlarmOnPacketSent(in_flight, + packet->encryption_level)) { + SetRetransmissionAlarm(); + } + } else if (in_flight || !retransmission_alarm_->IsSet()) { SetRetransmissionAlarm(); } SetPingAlarm(); @@ -3987,14 +4067,15 @@ void QuicConnection::OnPathMtuIncreased(QuicPacketLength packet_size) { std::unique_ptr QuicConnection::MakeSelfIssuedConnectionIdManager() { QUICHE_DCHECK((perspective_ == Perspective::IS_CLIENT && - !ClientConnectionId().IsEmpty()) || + !default_path_.client_connection_id.IsEmpty()) || (perspective_ == Perspective::IS_SERVER && - !ServerConnectionId().IsEmpty())); + !default_path_.server_connection_id.IsEmpty())); return std::make_unique( kMinNumOfActiveConnectionIds, - perspective_ == Perspective::IS_CLIENT ? ClientConnectionId() - : ServerConnectionId(), - clock_, alarm_factory_, this); + perspective_ == Perspective::IS_CLIENT + ? default_path_.client_connection_id + : default_path_.server_connection_id, + clock_, alarm_factory_, this, context()); } void QuicConnection::MaybeSendConnectionIdToClient() { @@ -4007,6 +4088,11 @@ void QuicConnection::MaybeSendConnectionIdToClient() { void QuicConnection::OnHandshakeComplete() { sent_packet_manager_.SetHandshakeConfirmed(); + if (connection_migration_use_new_cid_ && + perspective_ == Perspective::IS_SERVER && + self_issued_cid_manager_ != nullptr) { + self_issued_cid_manager_->MaybeSendNewConnectionIds(); + } if (send_ack_frequency_on_handshake_completion_ && sent_packet_manager_.CanSendAckFrequency()) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_can_send_ack_frequency, 2, 3); @@ -4053,9 +4139,7 @@ void QuicConnection::OnPingTimeout() { !visitor_->ShouldKeepConnectionAlive()) { return; } - SendPingAtLevel(use_encryption_level_context_ - ? framer().GetEncryptionLevelToSendApplicationData() - : encryption_level_); + SendPingAtLevel(framer().GetEncryptionLevelToSendApplicationData()); } void QuicConnection::SendAck() { @@ -4122,7 +4206,7 @@ void QuicConnection::OnRetransmissionTimeout() { blackhole_detector_.IsDetectionInProgress()) { // Stop detection in quiescence. QUICHE_DCHECK_EQ(QuicSentPacketManager::LOSS_MODE, retransmission_mode); - blackhole_detector_.StopDetection(); + blackhole_detector_.StopDetection(/*permanent=*/false); } WriteIfNotBlocked(); @@ -4153,15 +4237,28 @@ void QuicConnection::OnRetransmissionTimeout() { << retransmission_mode << ", send PING"; QUICHE_DCHECK_LT(0u, sent_packet_manager_.pending_timer_transmission_count()); - EncryptionLevel level = encryption_level_; - PacketNumberSpace packet_number_space = NUM_PACKET_NUMBER_SPACES; - if (SupportsMultiplePacketNumberSpaces() && - sent_packet_manager_ - .GetEarliestPacketSentTimeForPto(&packet_number_space) - .IsInitialized()) { - level = QuicUtils::GetEncryptionLevel(packet_number_space); + if (SupportsMultiplePacketNumberSpaces()) { + // Based on https://datatracker.ietf.org/doc/html/rfc9002#appendix-A.9 + PacketNumberSpace packet_number_space; + if (sent_packet_manager_ + .GetEarliestPacketSentTimeForPto(&packet_number_space) + .IsInitialized()) { + SendPingAtLevel(QuicUtils::GetEncryptionLevel(packet_number_space)); + } else { + // The client must PTO when there is nothing in flight if the server + // could be blocked from sending by the amplification limit + QUICHE_DCHECK_EQ(Perspective::IS_CLIENT, perspective_); + if (framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_HANDSHAKE)) { + SendPingAtLevel(ENCRYPTION_HANDSHAKE); + } else if (framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_INITIAL)) { + SendPingAtLevel(ENCRYPTION_INITIAL); + } else { + QUIC_BUG(quic_bug_no_pto) << "PTO fired but nothing was sent."; + } + } + } else { + SendPingAtLevel(encryption_level_); } - SendPingAtLevel(level); } if (retransmission_mode == QuicSentPacketManager::PTO_MODE) { sent_packet_manager_.AdjustPendingTimerTransmissions(); @@ -4321,9 +4418,14 @@ void QuicConnection::QueueUndecryptablePacket( } } QUIC_DVLOG(1) << ENDPOINT << "Queueing undecryptable packet."; - undecryptable_packets_.emplace_back(packet, decryption_level); + undecryptable_packets_.emplace_back(packet, decryption_level, + last_received_packet_info_); if (perspective_ == Perspective::IS_CLIENT) { - SetRetransmissionAlarm(); + if (!retransmission_alarm_->IsSet() || + GetRetransmissionDeadline() < retransmission_alarm_->deadline()) { + // Re-arm PTO only if we can make it sooner to speed up recovery. + SetRetransmissionAlarm(); + } } } @@ -4349,7 +4451,19 @@ void QuicConnection::MaybeProcessUndecryptablePackets() { debug_visitor_->OnAttemptingToProcessUndecryptablePacket( undecryptable_packet->encryption_level); } - if (framer_.ProcessPacket(*undecryptable_packet->packet)) { + bool processed = false; + if (reset_per_packet_state_for_undecryptable_packets_) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_reset_per_packet_state_for_undecryptable_packets, 1, 2); + last_received_packet_info_ = undecryptable_packet->packet_info; + last_size_ = undecryptable_packet->packet->length(); + current_packet_data_ = undecryptable_packet->packet->data(); + processed = framer_.ProcessPacket(*undecryptable_packet->packet); + current_packet_data_ = nullptr; + } else { + processed = framer_.ProcessPacket(*undecryptable_packet->packet); + } + if (processed) { QUIC_DVLOG(1) << ENDPOINT << "Processed undecryptable packet!"; iter = undecryptable_packets_.erase(iter); ++stats_.packets_processed; @@ -4370,10 +4484,17 @@ void QuicConnection::MaybeProcessUndecryptablePackets() { iter = undecryptable_packets_.erase(iter); } - // Once forward secure encryption is in use, there will be no - // new keys installed and hence any undecryptable packets will - // never be able to be decrypted. - if (encryption_level_ == ENCRYPTION_FORWARD_SECURE) { + // Once handshake is complete, there will be no new keys installed and hence + // any undecryptable packets will never be able to be decrypted. + bool clear_undecryptable_packets = + encryption_level_ == ENCRYPTION_FORWARD_SECURE; + if (GetQuicReloadableFlag( + quic_clear_undecryptable_packets_on_handshake_complete)) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_clear_undecryptable_packets_on_handshake_complete); + clear_undecryptable_packets = IsHandshakeComplete(); + } + if (clear_undecryptable_packets) { if (debug_visitor_ != nullptr) { for (const auto& undecryptable_packet : undecryptable_packets_) { debug_visitor_->OnUndecryptablePacket( @@ -4383,7 +4504,14 @@ void QuicConnection::MaybeProcessUndecryptablePackets() { undecryptable_packets_.clear(); } if (perspective_ == Perspective::IS_CLIENT) { - SetRetransmissionAlarm(); + if (!retransmission_alarm_->IsSet() || undecryptable_packets_.empty() || + GetRetransmissionDeadline() < retransmission_alarm_->deadline()) { + // 1) If there is still undecryptable packet, only re-arm PTO to make it + // sooner to speed up recovery. + // 2) If all undecryptable packets get processed, re-arm (which may + // postpone) PTO since no immediate recovery is needed. + SetRetransmissionAlarm(); + } } } @@ -4418,11 +4546,7 @@ bool QuicConnection::MaybeProcessCoalescedPackets() { } if (processed) { MaybeProcessUndecryptablePackets(); - if (donot_write_mid_packet_processing_) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_donot_write_mid_packet_processing, 2, - 3); - MaybeSendInResponseToPacket(); - } + MaybeSendInResponseToPacket(); } return processed; } @@ -4471,15 +4595,12 @@ void QuicConnection::SendConnectionClosePacket( const std::string& details) { // Always use the current path to send CONNECTION_CLOSE. QuicPacketCreator::ScopedPeerAddressContext context( - &packet_creator_, peer_address(), /*update_connection_id=*/false); + &packet_creator_, peer_address(), default_path_.client_connection_id, + default_path_.server_connection_id, connection_migration_use_new_cid_); if (!SupportsMultiplePacketNumberSpaces()) { QUIC_DLOG(INFO) << ENDPOINT << "Sending connection close packet."; - if (!use_encryption_level_context_) { - SetDefaultEncryptionLevel(GetConnectionCloseEncryptionLevel()); - } - ScopedEncryptionLevelContext context( - use_encryption_level_context_ ? this : nullptr, - GetConnectionCloseEncryptionLevel()); + ScopedEncryptionLevelContext context(this, + GetConnectionCloseEncryptionLevel()); if (version().CanSendCoalescedPackets()) { coalesced_packet_.Clear(); } @@ -4510,7 +4631,6 @@ void QuicConnection::SendConnectionClosePacket( ClearQueuedPackets(); return; } - const EncryptionLevel current_encryption_level = encryption_level_; ScopedPacketFlusher flusher(this); // Now that the connection is being closed, discard any unsent packets @@ -4528,11 +4648,7 @@ void QuicConnection::SendConnectionClosePacket( } QUIC_DLOG(INFO) << ENDPOINT << "Sending connection close packet at level: " << level; - if (!use_encryption_level_context_) { - SetDefaultEncryptionLevel(level); - } - ScopedEncryptionLevelContext context( - use_encryption_level_context_ ? this : nullptr, level); + ScopedEncryptionLevelContext context(this, level); // Bundle an ACK of the corresponding packet number space for debugging // purpose. bool send_ack = error != QUIC_PACKET_WRITE_ERROR && @@ -4565,9 +4681,6 @@ void QuicConnection::SendConnectionClosePacket( // Since the connection is closing, if the connection close packets were not // sent, then they should be discarded. ClearQueuedPackets(); - if (!use_encryption_level_context_) { - SetDefaultEncryptionLevel(current_encryption_level); - } } void QuicConnection::TearDownLocalConnectionState( @@ -4615,15 +4728,15 @@ void QuicConnection::TearDownLocalConnectionState( void QuicConnection::CancelAllAlarms() { QUIC_DVLOG(1) << "Cancelling all QuicConnection alarms."; - ack_alarm_->Cancel(); - ping_alarm_->Cancel(); - retransmission_alarm_->Cancel(); - send_alarm_->Cancel(); - mtu_discovery_alarm_->Cancel(); - process_undecryptable_packets_alarm_->Cancel(); - discard_previous_one_rtt_keys_alarm_->Cancel(); - discard_zero_rtt_decryption_keys_alarm_->Cancel(); - blackhole_detector_.StopDetection(); + ack_alarm_->PermanentCancel(); + ping_alarm_->PermanentCancel(); + retransmission_alarm_->PermanentCancel(); + send_alarm_->PermanentCancel(); + mtu_discovery_alarm_->PermanentCancel(); + process_undecryptable_packets_alarm_->PermanentCancel(); + discard_previous_one_rtt_keys_alarm_->PermanentCancel(); + discard_zero_rtt_decryption_keys_alarm_->PermanentCancel(); + blackhole_detector_.StopDetection(/*permanent=*/true); idle_network_detector_.StopDetection(); } @@ -4633,6 +4746,7 @@ QuicByteCount QuicConnection::max_packet_length() const { void QuicConnection::SetMaxPacketLength(QuicByteCount length) { long_term_mtu_ = length; + stats_.max_egress_mtu = std::max(stats_.max_egress_mtu, long_term_mtu_); MaybeUpdatePacketCreatorMaxPacketLengthAndPadding(); } @@ -4656,6 +4770,9 @@ void QuicConnection::SetNetworkTimeouts(QuicTime::Delta handshake_timeout, } void QuicConnection::SetPingAlarm() { + if (!connected_) { + return; + } if (perspective_ == Perspective::IS_SERVER && initial_retransmittable_on_wire_timeout_.IsInfinite()) { // The PING alarm exists to support two features: @@ -4999,7 +5116,7 @@ bool QuicConnection::SendGenericPathProbePacket( QUIC_DLOG(INFO) << ENDPOINT << "Sending path probe packet for connection_id = " - << ServerConnectionId(); + << default_path_.server_connection_id; std::unique_ptr probing_packet; if (!version().HasIetfQuicFrames()) { @@ -5044,7 +5161,7 @@ bool QuicConnection::WritePacketUsingWriter( const QuicTime packet_send_time = clock_->Now(); QUIC_DVLOG(2) << ENDPOINT << "Sending path probe packet for server connection ID " - << ServerConnectionId() << std::endl + << default_path_.server_connection_id << std::endl << quiche::QuicheTextUtils::HexDump(absl::string_view( packet->encrypted_buffer, packet->encrypted_length)); WriteResult result = writer->WritePacket( @@ -5179,11 +5296,31 @@ void QuicConnection::StartEffectivePeerMigration(AddressChangeType type) { QUIC_CODE_COUNT_N(quic_server_reverse_validate_new_path3, 3, 6); if (type == NO_CHANGE) { - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); QUIC_BUG(quic_bug_10511_36) << "EffectivePeerMigration started without address change."; return; } + if (GetQuicReloadableFlag( + quic_flush_pending_frames_and_padding_bytes_on_migration)) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_flush_pending_frames_and_padding_bytes_on_migration); + // There could be pending NEW_TOKEN_FRAME triggered by non-probing + // PATH_RESPONSE_FRAME in the same packet or pending padding bytes in the + // packet creator. + packet_creator_.FlushCurrentPacket(); + packet_creator_.SendRemainingPendingPadding(); + if (!connected_) { + return; + } + } else { + if (packet_creator_.HasPendingFrames()) { + packet_creator_.FlushCurrentPacket(); + if (!connected_) { + return; + } + } + } // Action items: // 1. Switch congestion controller; @@ -5204,6 +5341,7 @@ void QuicConnection::StartEffectivePeerMigration(AddressChangeType type) { const QuicSocketAddress previous_direct_peer_address = direct_peer_address_; PathState previous_default_path = std::move(default_path_); active_effective_peer_migration_type_ = type; + MaybeClearQueuedPacketsOnPathChange(); OnConnectionMigration(); // Update congestion controller if the address change type is not PORT_CHANGE. @@ -5238,25 +5376,23 @@ void QuicConnection::StartEffectivePeerMigration(AddressChangeType type) { std::move(alternative_path_.rtt_stats).value()); } } - // Update to the new peer address. - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); // Update the default path. - if (IsAlternativePath(last_packet_destination_address_, + if (IsAlternativePath(last_received_packet_info_.destination_address, current_effective_peer_address)) { - default_path_ = std::move(alternative_path_); + SetDefaultPathState(std::move(alternative_path_)); } else { QuicConnectionId client_connection_id; - bool stateless_reset_token_received = false; - StatelessResetToken stateless_reset_token; - FindMatchingClientConnectionIdOrToken( + absl::optional stateless_reset_token; + FindMatchingOrNewClientConnectionIdOrToken( previous_default_path, alternative_path_, last_packet_destination_connection_id_, &client_connection_id, - &stateless_reset_token_received, &stateless_reset_token); - default_path_ = PathState( - last_packet_destination_address_, current_effective_peer_address, - client_connection_id, last_packet_destination_connection_id_, - stateless_reset_token_received, stateless_reset_token); + &stateless_reset_token); + SetDefaultPathState(PathState( + last_received_packet_info_.destination_address, + current_effective_peer_address, client_connection_id, + last_packet_destination_connection_id_, stateless_reset_token)); // The path is considered validated if its peer IP address matches any // validated path's peer IP address. default_path_.validated = @@ -5265,10 +5401,10 @@ void QuicConnection::StartEffectivePeerMigration(AddressChangeType type) { alternative_path_.validated) || (previous_default_path.validated && type == PORT_CHANGE); } - if (!current_incoming_packet_received_bytes_counted_) { + if (!last_received_packet_info_.received_bytes_counted) { // Increment bytes counting on the new default path. default_path_.bytes_received_before_address_validation += last_size_; - current_incoming_packet_received_bytes_counted_ = true; + last_received_packet_info_.received_bytes_counted = true; } if (!previous_default_path.validated) { @@ -5449,31 +5585,30 @@ bool QuicConnection::UpdatePacketContent(QuicFrameType type) { QuicSocketAddress current_effective_peer_address = GetEffectivePeerAddressFromCurrentPacket(); if (!count_bytes_on_alternative_path_separately_ || - IsDefaultPath(last_packet_destination_address_, - last_packet_source_address_)) { + IsDefaultPath(last_received_packet_info_.destination_address, + last_received_packet_info_.source_address)) { return connected_; } QUIC_CODE_COUNT_N(quic_count_bytes_on_alternative_path_seperately, 3, 5); if (perspective_ == Perspective::IS_SERVER && type == PATH_CHALLENGE_FRAME && - !IsAlternativePath(last_packet_destination_address_, + !IsAlternativePath(last_received_packet_info_.destination_address, current_effective_peer_address)) { QUIC_DVLOG(1) << "The peer is probing a new path with effective peer address " << current_effective_peer_address << ", self address " - << last_packet_destination_address_; + << last_received_packet_info_.destination_address; if (!validate_client_addresses_) { - QuicConnectionId client_connection_id; - bool stateless_reset_token_received = false; - StatelessResetToken stateless_reset_token; - FindMatchingClientConnectionIdOrToken( + QuicConnectionId client_cid; + absl::optional stateless_reset_token; + FindMatchingOrNewClientConnectionIdOrToken( default_path_, alternative_path_, - last_packet_destination_connection_id_, &client_connection_id, - &stateless_reset_token_received, &stateless_reset_token); + last_packet_destination_connection_id_, &client_cid, + &stateless_reset_token); alternative_path_ = PathState( - last_packet_destination_address_, current_effective_peer_address, - client_connection_id, last_packet_destination_connection_id_, - stateless_reset_token_received, stateless_reset_token); + last_received_packet_info_.destination_address, + current_effective_peer_address, client_cid, + last_packet_destination_connection_id_, stateless_reset_token); } else if (!default_path_.validated) { QUIC_CODE_COUNT_N(quic_server_reverse_validate_new_path3, 4, 6); // Skip reverse path validation because either handshake hasn't @@ -5491,36 +5626,19 @@ bool QuicConnection::UpdatePacketContent(QuicFrameType type) { } else if (!IsReceivedPeerAddressValidated()) { QUIC_CODE_COUNT_N(quic_server_reverse_validate_new_path3, 5, 6); QuicConnectionId client_connection_id; - bool stateless_reset_token_received; - StatelessResetToken stateless_reset_token; - FindMatchingClientConnectionIdOrToken( + absl::optional stateless_reset_token; + FindMatchingOrNewClientConnectionIdOrToken( default_path_, alternative_path_, last_packet_destination_connection_id_, &client_connection_id, - &stateless_reset_token_received, &stateless_reset_token); + &stateless_reset_token); // Only override alternative path state upon receiving a PATH_CHALLENGE // from an unvalidated peer address, and the connection isn't validating // a recent peer migration. alternative_path_ = PathState( - last_packet_destination_address_, current_effective_peer_address, - client_connection_id, last_packet_destination_connection_id_, - stateless_reset_token_received, stateless_reset_token); - if (group_path_response_and_challenge_sending_closer_) { - should_proactively_validate_peer_address_on_path_challenge_ = true; - } else { - // Conditions to proactively validate peer address: - // The perspective is server - // The PATH_CHALLENGE is received on an unvalidated alternative path. - // The connection isn't validating migrated peer address, which is of - // higher prority. - QUIC_DVLOG(1) << "Proactively validate the effective peer address " - << current_effective_peer_address; - ValidatePath( - std::make_unique( - default_path_.self_address, current_effective_peer_address, - current_effective_peer_address, this), - std::make_unique( - this, peer_address())); - } + last_received_packet_info_.destination_address, + current_effective_peer_address, client_connection_id, + last_packet_destination_connection_id_, stateless_reset_token); + should_proactively_validate_peer_address_on_path_challenge_ = true; } } MaybeUpdateBytesReceivedFromAlternativeAddress(last_size_); @@ -5560,15 +5678,17 @@ bool QuicConnection::UpdatePacketContent(QuicFrameType type) { << current_effective_peer_migration_type_; } else { is_current_packet_connectivity_probing_ = - (last_packet_source_address_ != peer_address()) || - (last_packet_destination_address_ != default_path_.self_address); + (last_received_packet_info_.source_address != peer_address()) || + (last_received_packet_info_.destination_address != + default_path_.self_address); QUIC_DLOG_IF(INFO, is_current_packet_connectivity_probing_) << ENDPOINT << "Detected connectivity probing packet. " - "last_packet_source_address_:" - << last_packet_source_address_ << ", peer_address_:" << peer_address() - << ", last_packet_destination_address_:" - << last_packet_destination_address_ + "last_packet_source_address:" + << last_received_packet_info_.source_address + << ", peer_address_:" << peer_address() + << ", last_packet_destination_address:" + << last_received_packet_info_.destination_address << ", default path self_address :" << default_path_.self_address; } return connected_; @@ -5577,7 +5697,7 @@ bool QuicConnection::UpdatePacketContent(QuicFrameType type) { current_packet_content_ = NOT_PADDED_PING; if (GetLargestReceivedPacket().IsInitialized() && last_header_.packet_number == GetLargestReceivedPacket()) { - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); if (current_effective_peer_migration_type_ != NO_CHANGE) { // Start effective peer migration immediately when the current packet is // confirmed not a connectivity probing packet. @@ -5590,10 +5710,6 @@ bool QuicConnection::UpdatePacketContent(QuicFrameType type) { void QuicConnection::MaybeStartIetfPeerMigration() { QUICHE_DCHECK(version().HasIetfQuicFrames()); - if (!start_peer_migration_earlier_) { - return; - } - QUIC_CODE_COUNT(quic_start_peer_migration_earlier); if (current_effective_peer_migration_type_ != NO_CHANGE && !IsHandshakeConfirmed()) { QUIC_LOG_EVERY_N_SEC(INFO, 60) @@ -5620,11 +5736,11 @@ void QuicConnection::MaybeStartIetfPeerMigration() { // TODO(fayang): When multiple packet number spaces is supported, only // start peer migration for the application data. if (!validate_client_addresses_) { - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); } StartEffectivePeerMigration(current_effective_peer_migration_type_); } else { - UpdatePeerAddress(last_packet_source_address_); + UpdatePeerAddress(last_received_packet_info_.source_address); } } current_effective_peer_migration_type_ = NO_CHANGE; @@ -5651,7 +5767,7 @@ void QuicConnection::PostProcessAfterAckFrame(bool send_stop_waiting, // In case no new packets get acknowledged, it is possible packets are // detected lost because of time based loss detection. Cancel blackhole // detection if there is no packets in flight. - blackhole_detector_.StopDetection(); + blackhole_detector_.StopDetection(/*permanent=*/false); } if (send_stop_waiting) { @@ -5698,14 +5814,14 @@ void QuicConnection::ResetAckStates() { } MessageStatus QuicConnection::SendMessage(QuicMessageId message_id, - QuicMemSliceSpan message, + absl::Span message, bool flush) { if (!VersionSupportsMessageFrames(transport_version())) { QUIC_BUG(quic_bug_10511_38) << "MESSAGE frame is not supported for version " << transport_version(); return MESSAGE_STATUS_UNSUPPORTED; } - if (message.total_length() > GetCurrentLargestMessagePayload()) { + if (MemSliceSpanTotalSize(message) > GetCurrentLargestMessagePayload()) { return MESSAGE_STATUS_TOO_LARGE; } if (!connected_ || (!flush && !CanWrite(HAS_RETRANSMITTABLE_DATA))) { @@ -5804,8 +5920,6 @@ void QuicConnection::SendAllPendingAcks() { if (!earliest_ack_timeout.IsInitialized()) { return; } - // Latches current encryption level. - const EncryptionLevel current_encryption_level = encryption_level_; for (int8_t i = INITIAL_DATA; i <= APPLICATION_DATA; ++i) { const QuicTime ack_timeout = uber_received_packet_manager_.GetAckTimeout( static_cast(i)); @@ -5825,14 +5939,8 @@ void QuicConnection::SendAllPendingAcks() { QUIC_DVLOG(1) << ENDPOINT << "Sending ACK of packet number space " << PacketNumberSpaceToString( static_cast(i)); - // Switch to the appropriate encryption level. - if (!use_encryption_level_context_) { - SetDefaultEncryptionLevel( - QuicUtils::GetEncryptionLevel(static_cast(i))); - } ScopedEncryptionLevelContext context( - use_encryption_level_context_ ? this : nullptr, - QuicUtils::GetEncryptionLevel(static_cast(i))); + this, QuicUtils::GetEncryptionLevel(static_cast(i))); QuicFrames frames; frames.push_back(uber_received_packet_manager_.GetUpdatedAckFrame( static_cast(i), clock_->ApproximateNow())); @@ -5848,10 +5956,6 @@ void QuicConnection::SendAllPendingAcks() { } ResetAckStates(); } - if (!use_encryption_level_context_) { - // Restores encryption level. - SetDefaultEncryptionLevel(current_encryption_level); - } const QuicTime timeout = uber_received_packet_manager_.GetEarliestAckTimeout(); @@ -5951,6 +6055,9 @@ bool QuicConnection::FlushCoalescedPacket() { if (length == 0) { return false; } + if (debug_visitor_ != nullptr) { + debug_visitor_->OnCoalescedPacketSent(coalesced_packet_, length); + } QUIC_DVLOG(1) << ENDPOINT << "Sending coalesced packet " << coalesced_packet_.ToString(length); @@ -5960,9 +6067,6 @@ bool QuicConnection::FlushCoalescedPacket() { buffered_packets_.emplace_back( buffer, static_cast(length), coalesced_packet_.self_address(), coalesced_packet_.peer_address()); - if (debug_visitor_ != nullptr) { - debug_visitor_->OnCoalescedPacketSent(coalesced_packet_, length); - } return true; } @@ -5983,9 +6087,6 @@ bool QuicConnection::FlushCoalescedPacket() { coalesced_packet_.self_address(), coalesced_packet_.peer_address()); } } - if (debug_visitor_ != nullptr) { - debug_visitor_->OnCoalescedPacketSent(coalesced_packet_, length); - } // Account for added padding. if (length > coalesced_packet_.length()) { size_t padding_size = length - coalesced_packet_.length(); @@ -6043,6 +6144,9 @@ void QuicConnection::SetLargestReceivedPacketWithAck( } void QuicConnection::OnForwardProgressMade() { + if (!connected_) { + return; + } if (is_path_degrading_) { visitor_->OnForwardProgressMadeAfterPathDegrading(); is_path_degrading_ = false; @@ -6054,10 +6158,11 @@ void QuicConnection::OnForwardProgressMade() { GetPathMtuReductionDeadline()); } else { // Stop detections in quiecense. - blackhole_detector_.StopDetection(); + blackhole_detector_.StopDetection(/*permanent=*/false); } QUIC_BUG_IF(quic_bug_12714_35, - default_enable_5rto_blackhole_detection_ && + perspective_ == Perspective::IS_SERVER && + default_enable_5rto_blackhole_detection_ && blackhole_detector_.IsDetectionInProgress() && !sent_packet_manager_.HasInFlightPackets()) << ENDPOINT @@ -6162,21 +6267,16 @@ void QuicConnection::set_client_connection_id( << client_connection_id << " with unsupported version " << version(); return; } - if (use_connection_id_on_default_path_) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_use_connection_id_on_default_path, 1, 2); - default_path_.client_connection_id = client_connection_id; - } else { - client_connection_id_ = client_connection_id; - } + default_path_.client_connection_id = client_connection_id; client_connection_id_is_set_ = true; - if (support_multiple_connection_ids_ && !client_connection_id.IsEmpty()) { + if (version().HasIetfQuicFrames() && !client_connection_id.IsEmpty()) { if (perspective_ == Perspective::IS_SERVER) { QUICHE_DCHECK(peer_issued_cid_manager_ == nullptr); peer_issued_cid_manager_ = std::make_unique( kMinNumOfActiveConnectionIds, client_connection_id, clock_, - alarm_factory_, this); + alarm_factory_, this, context()); } else { // Note in Chromium client, set_client_connection_id is not called and // thus self_issued_cid_manager_ should be null. @@ -6184,11 +6284,12 @@ void QuicConnection::set_client_connection_id( } } QUIC_DLOG(INFO) << ENDPOINT << "setting client connection ID to " - << ClientConnectionId() + << default_path_.client_connection_id << " for connection with server connection ID " - << ServerConnectionId(); - packet_creator_.SetClientConnectionId(ClientConnectionId()); - framer_.SetExpectedClientConnectionIdLength(ClientConnectionId().length()); + << default_path_.server_connection_id; + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + framer_.SetExpectedClientConnectionIdLength( + default_path_.client_connection_id.length()); } void QuicConnection::OnPathDegradingDetected() { @@ -6237,6 +6338,10 @@ void QuicConnection::OnIdleNetworkDetected() { "No recent network activity after ", duration.ToDebuggingValue(), ". Timeout:", idle_network_detector_.idle_network_timeout().ToDebuggingValue()); + if (perspective() == Perspective::IS_CLIENT && version().UsesTls() && + !IsHandshakeComplete()) { + absl::StrAppend(&error_details, UndecryptablePacketsInfo()); + } QUIC_DVLOG(1) << ENDPOINT << error_details; const bool has_consecutive_pto = sent_packet_manager_.GetConsecutiveTlpCount() > 0 || @@ -6266,9 +6371,10 @@ void QuicConnection::OnIdleNetworkDetected() { void QuicConnection::OnPeerIssuedConnectionIdRetired() { QUICHE_DCHECK(peer_issued_cid_manager_ != nullptr); - QuicConnectionId* default_path_cid = perspective_ == Perspective::IS_CLIENT - ? &ServerConnectionId() - : &ClientConnectionId(); + QuicConnectionId* default_path_cid = + perspective_ == Perspective::IS_CLIENT + ? &default_path_.server_connection_id + : &default_path_.client_connection_id; QuicConnectionId* alternative_path_cid = perspective_ == Perspective::IS_CLIENT ? &alternative_path_.server_connection_id @@ -6280,8 +6386,7 @@ void QuicConnection::OnPeerIssuedConnectionIdRetired() { *default_path_cid = QuicConnectionId(); } // TODO(haoyuewang) Handle the change for default_path_ & alternatvie_path_ - // via the same helper function after use_connection_id_on_default_path_ is - // default true. + // via the same helper function. if (default_path_cid->IsEmpty()) { // Try setting a new connection ID now such that subsequent // RetireConnectionId frames can be sent on the default path. @@ -6289,15 +6394,8 @@ void QuicConnection::OnPeerIssuedConnectionIdRetired() { peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); if (unused_connection_id_data != nullptr) { *default_path_cid = unused_connection_id_data->connection_id; - if (use_connection_id_on_default_path_) { - default_path_.stateless_reset_token = - unused_connection_id_data->stateless_reset_token; - default_path_.stateless_reset_token_received = true; - } else { - received_stateless_reset_token_ = - unused_connection_id_data->stateless_reset_token; - stateless_reset_token_received_ = true; - } + default_path_.stateless_reset_token = + unused_connection_id_data->stateless_reset_token; if (perspective_ == Perspective::IS_CLIENT) { packet_creator_.SetServerConnectionId( unused_connection_id_data->connection_id); @@ -6307,25 +6405,20 @@ void QuicConnection::OnPeerIssuedConnectionIdRetired() { } } } - if (use_connection_id_on_default_path_) { - if (default_path_and_alternative_path_use_the_same_peer_connection_id) { - *alternative_path_cid = *default_path_cid; - alternative_path_.stateless_reset_token_received = - default_path_.stateless_reset_token_received; + if (default_path_and_alternative_path_use_the_same_peer_connection_id) { + *alternative_path_cid = *default_path_cid; + alternative_path_.stateless_reset_token = + default_path_.stateless_reset_token; + } else if (!alternative_path_cid->IsEmpty() && + !peer_issued_cid_manager_->IsConnectionIdActive( + *alternative_path_cid)) { + *alternative_path_cid = EmptyQuicConnectionId(); + const QuicConnectionIdData* unused_connection_id_data = + peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); + if (unused_connection_id_data != nullptr) { + *alternative_path_cid = unused_connection_id_data->connection_id; alternative_path_.stateless_reset_token = - default_path_.stateless_reset_token; - } else if (!alternative_path_cid->IsEmpty() && - !peer_issued_cid_manager_->IsConnectionIdActive( - *alternative_path_cid)) { - *alternative_path_cid = EmptyQuicConnectionId(); - const QuicConnectionIdData* unused_connection_id_data = - peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); - if (unused_connection_id_data != nullptr) { - *alternative_path_cid = unused_connection_id_data->connection_id; - alternative_path_.stateless_reset_token = - unused_connection_id_data->stateless_reset_token; - alternative_path_.stateless_reset_token_received = true; - } + unused_connection_id_data->stateless_reset_token; } } @@ -6333,6 +6426,7 @@ void QuicConnection::OnPeerIssuedConnectionIdRetired() { peer_issued_cid_manager_->ConsumeToBeRetiredConnectionIdSequenceNumbers(); QUICHE_DCHECK(!retired_cid_sequence_numbers.empty()); for (const auto& sequence_number : retired_cid_sequence_numbers) { + ++stats_.num_retire_connection_id_sent; visitor_->SendRetireConnectionId(sequence_number); } } @@ -6340,6 +6434,7 @@ void QuicConnection::OnPeerIssuedConnectionIdRetired() { bool QuicConnection::SendNewConnectionId( const QuicNewConnectionIdFrame& frame) { visitor_->SendNewConnectionId(frame); + ++stats_.num_new_connection_id_sent; return connected_; } @@ -6432,8 +6527,37 @@ bool QuicConnection::SendPathChallenge( const QuicPathFrameBuffer& data_buffer, const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, - const QuicSocketAddress& /*effective_peer_address*/, + const QuicSocketAddress& effective_peer_address, QuicPacketWriter* writer) { + if (!framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE)) { + return connected_; + } + if (connection_migration_use_new_cid_) { + { + QuicConnectionId client_cid, server_cid; + FindOnPathConnectionIds(self_address, effective_peer_address, &client_cid, + &server_cid); + QuicPacketCreator::ScopedPeerAddressContext context( + &packet_creator_, peer_address, client_cid, server_cid, + connection_migration_use_new_cid_); + if (writer == writer_) { + ScopedPacketFlusher flusher(this); + // It's on current path, add the PATH_CHALLENGE the same way as other + // frames. This may cause connection to be closed. + packet_creator_.AddPathChallengeFrame(data_buffer); + } else { + std::unique_ptr probing_packet = + packet_creator_.SerializePathChallengeConnectivityProbingPacket( + data_buffer); + QUICHE_DCHECK_EQ(IsRetransmittable(*probing_packet), + NO_RETRANSMITTABLE_DATA); + QUICHE_DCHECK_EQ(self_address, alternative_path_.self_address); + WritePacketUsingWriter(std::move(probing_packet), writer, self_address, + peer_address, /*measure_rtt=*/false); + } + } + return connected_; + } if (writer == writer_) { ScopedPacketFlusher flusher(this); { @@ -6477,7 +6601,6 @@ void QuicConnection::ValidatePath( alternative_path_ = PathState( context->self_address(), context->peer_address(), default_path_.client_connection_id, default_path_.server_connection_id, - default_path_.stateless_reset_token_received, default_path_.stateless_reset_token); } if (path_validator_.HasPendingPathValidation()) { @@ -6503,8 +6626,7 @@ void QuicConnection::ValidatePath( return; } QuicConnectionId client_connection_id, server_connection_id; - StatelessResetToken stateless_reset_token; - bool stateless_reset_token_received = false; + absl::optional stateless_reset_token; if (self_issued_cid_manager_ != nullptr) { client_connection_id = *self_issued_cid_manager_->ConsumeOneConnectionId(); @@ -6513,27 +6635,37 @@ void QuicConnection::ValidatePath( const auto* connection_id_data = peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); server_connection_id = connection_id_data->connection_id; - stateless_reset_token_received = true; stateless_reset_token = connection_id_data->stateless_reset_token; } - alternative_path_ = - PathState(context->self_address(), context->peer_address(), - client_connection_id, server_connection_id, - stateless_reset_token_received, stateless_reset_token); + alternative_path_ = PathState(context->self_address(), + context->peer_address(), client_connection_id, + server_connection_id, stateless_reset_token); } path_validator_.StartPathValidation(std::move(context), std::move(result_delegate)); } -bool QuicConnection::SendPathResponse(const QuicPathFrameBuffer& data_buffer, - QuicSocketAddress peer_address_to_send) { +bool QuicConnection::SendPathResponse( + const QuicPathFrameBuffer& data_buffer, + const QuicSocketAddress& peer_address_to_send, + const QuicSocketAddress& effective_peer_address) { + if (!framer_.HasEncrypterOfEncryptionLevel(ENCRYPTION_FORWARD_SECURE)) { + return false; + } + QuicConnectionId client_cid, server_cid; + if (connection_migration_use_new_cid_) { + FindOnPathConnectionIds(last_received_packet_info_.destination_address, + effective_peer_address, &client_cid, &server_cid); + } // Send PATH_RESPONSE using the provided peer address. If the creator has been // using a different peer address, it will flush before and after serializing // the current PATH_RESPONSE. QuicPacketCreator::ScopedPeerAddressContext context( - &packet_creator_, peer_address_to_send, /*update_connection_id=*/false); + &packet_creator_, peer_address_to_send, client_cid, server_cid, + connection_migration_use_new_cid_); QUIC_DVLOG(1) << ENDPOINT << "Send PATH_RESPONSE to " << peer_address_to_send; - if (default_path_.self_address == last_packet_destination_address_) { + if (default_path_.self_address == + last_received_packet_info_.destination_address) { // The PATH_CHALLENGE is received on the default socket. Respond on the same // socket. return packet_creator_.AddPathResponseFrame(data_buffer); @@ -6544,8 +6676,9 @@ bool QuicConnection::SendPathResponse(const QuicPathFrameBuffer& data_buffer, // used to send PATH_RESPONSE. if (!path_validator_.HasPendingPathValidation() || path_validator_.GetContext()->self_address() != - last_packet_destination_address_) { - // Ignore this PATH_CHALLENGE if it's received from an uninteresting socket. + last_received_packet_info_.destination_address) { + // Ignore this PATH_CHALLENGE if it's received from an uninteresting + // socket. return true; } QuicPacketWriter* writer = path_validator_.GetContext()->WriterToUse(); @@ -6556,12 +6689,13 @@ bool QuicConnection::SendPathResponse(const QuicPathFrameBuffer& data_buffer, QUICHE_DCHECK_EQ(IsRetransmittable(*probing_packet), NO_RETRANSMITTABLE_DATA); QUIC_DVLOG(1) << ENDPOINT << "Send PATH_RESPONSE from alternative socket with address " - << last_packet_destination_address_; + << last_received_packet_info_.destination_address; // Ignore the return value to treat write error on the alternative writer as // part of network error. If the writer becomes blocked, wait for the peer to // send another PATH_CHALLENGE. WritePacketUsingWriter(std::move(probing_packet), writer, - last_packet_destination_address_, peer_address_to_send, + last_received_packet_info_.destination_address, + peer_address_to_send, /*measure_rtt=*/false); return true; } @@ -6597,18 +6731,10 @@ bool QuicConnection::UpdateConnectionIdsOnClientMigration( QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); if (IsAlternativePath(self_address, peer_address)) { // Client migration is after path validation. - if (peer_issued_cid_manager_ != nullptr) { - QUICHE_DCHECK(!default_path_.server_connection_id.IsEmpty()); - packet_creator_.FlushCurrentPacket(); - } default_path_.client_connection_id = alternative_path_.client_connection_id; default_path_.server_connection_id = alternative_path_.server_connection_id; default_path_.stateless_reset_token = alternative_path_.stateless_reset_token; - default_path_.stateless_reset_token_received = - alternative_path_.stateless_reset_token_received; - packet_creator_.SetClientConnectionId(default_path_.client_connection_id); - packet_creator_.SetServerConnectionId(default_path_.server_connection_id); return true; } // Client migration is without path validation. @@ -6632,12 +6758,9 @@ bool QuicConnection::UpdateConnectionIdsOnClientMigration( const auto* connection_id_data = peer_issued_cid_manager_->ConsumeOneUnusedConnectionId(); default_path_.server_connection_id = connection_id_data->connection_id; - default_path_.stateless_reset_token_received = true; default_path_.stateless_reset_token = connection_id_data->stateless_reset_token; } - packet_creator_.SetClientConnectionId(default_path_.client_connection_id); - packet_creator_.SetServerConnectionId(default_path_.server_connection_id); return true; } @@ -6650,21 +6773,39 @@ void QuicConnection::RetirePeerIssuedConnectionIdsNoLongerOnPath() { peer_issued_cid_manager_->MaybeRetireUnusedConnectionIds( {default_path_.server_connection_id, alternative_path_.server_connection_id}); + } else { + peer_issued_cid_manager_->MaybeRetireUnusedConnectionIds( + {default_path_.client_connection_id, + alternative_path_.client_connection_id}); } - // TODO(haoyuewang) Do the same on the server side. } bool QuicConnection::MigratePath(const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, QuicPacketWriter* writer, bool owns_writer) { + QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); if (!connected_) { + if (owns_writer) { + delete writer; + } return false; } + QUICHE_DCHECK(!version().UsesHttp3() || IsHandshakeConfirmed()); - if (connection_migration_use_new_cid_ && - !UpdateConnectionIdsOnClientMigration(self_address, peer_address)) { - return false; + if (connection_migration_use_new_cid_) { + if (!UpdateConnectionIdsOnClientMigration(self_address, peer_address)) { + if (owns_writer) { + delete writer; + } + return false; + } + if (packet_creator_.GetServerConnectionId().length() != + default_path_.server_connection_id.length()) { + packet_creator_.FlushCurrentPacket(); + } + packet_creator_.SetClientConnectionId(default_path_.client_connection_id); + packet_creator_.SetServerConnectionId(default_path_.server_connection_id); } const auto self_address_change_type = QuicUtils::DetermineAddressChangeType( @@ -6680,6 +6821,7 @@ bool QuicConnection::MigratePath(const QuicSocketAddress& self_address, SetSelfAddress(self_address); UpdatePeerAddress(peer_address); SetQuicPacketWriter(writer, owns_writer); + MaybeClearQueuedPacketsOnPathChange(); OnSuccessfulMigration(is_port_change); return true; } @@ -6689,31 +6831,55 @@ void QuicConnection::OnPathValidationFailureAtClient() { QUICHE_DCHECK(perspective_ == Perspective::IS_CLIENT); alternative_path_.Clear(); } + // The alarm to retire connection IDs no longer on paths is scheduled at the + // end of writing and reading packet. On path validation failure, there could + // be no packet to write or read. Hence the retirement alarm for the + // connection ID associated with the failed path needs to be proactively + // scheduled here. + RetirePeerIssuedConnectionIdsNoLongerOnPath(); +} + +QuicConnectionId QuicConnection::GetOneActiveServerConnectionId() const { + if (perspective_ == Perspective::IS_CLIENT || + self_issued_cid_manager_ == nullptr) { + return connection_id(); + } + auto active_connection_ids = GetActiveServerConnectionIds(); + QUIC_BUG_IF(quic_bug_6944, active_connection_ids.empty()); + if (active_connection_ids.empty() || + std::find(active_connection_ids.begin(), active_connection_ids.end(), + connection_id()) != active_connection_ids.end()) { + return connection_id(); + } + QUICHE_CODE_COUNT(connection_id_on_default_path_has_been_retired); + auto active_connection_id = + self_issued_cid_manager_->GetOneActiveConnectionId(); + return active_connection_id; } std::vector QuicConnection::GetActiveServerConnectionIds() const { - if (!support_multiple_connection_ids_ || - self_issued_cid_manager_ == nullptr) { - return {ServerConnectionId()}; + if (self_issued_cid_manager_ == nullptr) { + return {default_path_.server_connection_id}; } + QUICHE_DCHECK(version().HasIetfQuicFrames()); return self_issued_cid_manager_->GetUnretiredConnectionIds(); } void QuicConnection::CreateConnectionIdManager() { - if (!support_multiple_connection_ids_) { + if (!version().HasIetfQuicFrames()) { return; } if (perspective_ == Perspective::IS_CLIENT) { - if (!ServerConnectionId().IsEmpty()) { + if (!default_path_.server_connection_id.IsEmpty()) { peer_issued_cid_manager_ = std::make_unique( - kMinNumOfActiveConnectionIds, ServerConnectionId(), clock_, - alarm_factory_, this); + kMinNumOfActiveConnectionIds, default_path_.server_connection_id, + clock_, alarm_factory_, this, context()); } } else { - if (!ServerConnectionId().IsEmpty()) { + if (!default_path_.server_connection_id.IsEmpty()) { self_issued_cid_manager_ = MakeSelfIssuedConnectionIdManager(); } } @@ -6769,20 +6935,20 @@ void QuicConnection::MaybeUpdateBytesReceivedFromAlternativeAddress( QuicByteCount received_packet_size) { if (!version().SupportsAntiAmplificationLimit() || perspective_ != Perspective::IS_SERVER || - !IsAlternativePath(last_packet_destination_address_, + !IsAlternativePath(last_received_packet_info_.destination_address, GetEffectivePeerAddressFromCurrentPacket()) || - current_incoming_packet_received_bytes_counted_) { + last_received_packet_info_.received_bytes_counted) { return; } // Only update bytes received if this probing frame is received on the most // recent alternative path. - QUICHE_DCHECK(!IsDefaultPath(last_packet_destination_address_, + QUICHE_DCHECK(!IsDefaultPath(last_received_packet_info_.destination_address, GetEffectivePeerAddressFromCurrentPacket())); if (!alternative_path_.validated) { alternative_path_.bytes_received_before_address_validation += received_packet_size; } - current_incoming_packet_received_bytes_counted_ = true; + last_received_packet_info_.received_bytes_counted = true; } bool QuicConnection::IsDefaultPath( @@ -6804,12 +6970,12 @@ void QuicConnection::PathState::Clear() { peer_address = QuicSocketAddress(); client_connection_id = {}; server_connection_id = {}; - stateless_reset_token_received = false; validated = false; bytes_received_before_address_validation = 0; bytes_sent_before_address_validation = 0; send_algorithm = nullptr; rtt_stats = absl::nullopt; + stateless_reset_token.reset(); } QuicConnection::PathState::PathState(PathState&& other) { @@ -6823,7 +6989,6 @@ QuicConnection::PathState& QuicConnection::PathState::operator=( peer_address = other.peer_address; client_connection_id = other.client_connection_id; server_connection_id = other.server_connection_id; - stateless_reset_token_received = other.stateless_reset_token_received; stateless_reset_token = other.stateless_reset_token; validated = other.validated; bytes_received_before_address_validation = @@ -6859,7 +7024,12 @@ QuicConnection::ReversePathValidationResultDelegate:: const QuicSocketAddress& direct_peer_address) : QuicPathValidator::ResultDelegate(), connection_(connection), - original_direct_peer_address_(direct_peer_address) {} + original_direct_peer_address_(direct_peer_address), + peer_address_default_path_(connection->direct_peer_address_), + peer_address_alternative_path_( + connection_->alternative_path_.peer_address), + active_effective_peer_migration_type_( + connection_->active_effective_peer_migration_type_) {} void QuicConnection::ReversePathValidationResultDelegate:: OnPathValidationSuccess( @@ -6867,10 +7037,36 @@ void QuicConnection::ReversePathValidationResultDelegate:: QUIC_DLOG(INFO) << "Successfully validated new path " << *context; if (connection_->IsDefaultPath(context->self_address(), context->peer_address())) { + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 3, 6); + if (connection_->active_effective_peer_migration_type_ == NO_CHANGE) { + connection_->quic_bug_10511_43_timestamp_ = + connection_->clock_->WallNow(); + connection_->quic_bug_10511_43_error_detail_ = absl::StrCat( + "Reverse path validation on default path from ", + context->self_address().ToString(), " to ", + context->peer_address().ToString(), + " completed without active peer address change: current " + "peer address on default path ", + connection_->direct_peer_address_.ToString(), + ", peer address on default path when the reverse path " + "validation was kicked off ", + peer_address_default_path_.ToString(), + ", peer address on alternative path when the reverse " + "path validation was kicked off ", + peer_address_alternative_path_.ToString(), + ", with active_effective_peer_migration_type_ = ", + AddressChangeTypeToString(active_effective_peer_migration_type_), + ". The last received packet number ", + connection_->last_header_.packet_number.ToString(), + " Connection is connected: ", connection_->connected_); + QUIC_BUG(quic_bug_10511_43) + << connection_->quic_bug_10511_43_error_detail_; + } connection_->OnEffectivePeerMigrationValidated(); } else { QUICHE_DCHECK(connection_->IsAlternativePath( context->self_address(), context->effective_peer_address())); + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 4, 6); QUIC_DVLOG(1) << "Mark alternative peer address " << context->effective_peer_address() << " validated."; connection_->alternative_path_.validated = true; @@ -6887,9 +7083,11 @@ void QuicConnection::ReversePathValidationResultDelegate:: if (connection_->IsDefaultPath(context->self_address(), context->peer_address())) { // Only act upon validation failure on the default path. + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 5, 6); connection_->RestoreToLastValidatedPath(original_direct_peer_address_); } else if (connection_->IsAlternativePath( context->self_address(), context->effective_peer_address())) { + QUIC_CODE_COUNT_N(quic_kick_off_client_address_validation, 6, 6); connection_->alternative_path_.Clear(); } } @@ -6921,6 +7119,7 @@ void QuicConnection::RestoreToLastValidatedPath( ConnectionCloseBehavior::SILENT_CLOSE); return; } + MaybeClearQueuedPacketsOnPathChange(); // Revert congestion control context to old state. OnPeerIpAddressChanged(); @@ -6935,7 +7134,7 @@ void QuicConnection::RestoreToLastValidatedPath( } UpdatePeerAddress(original_direct_peer_address); - default_path_ = std::move(alternative_path_); + SetDefaultPathState(std::move(alternative_path_)); active_effective_peer_migration_type_ = NO_CHANGE; ++stats_.num_invalid_peer_migration; @@ -6957,7 +7156,7 @@ QuicConnection::OnPeerIpAddressChanged() { // re-arm it. SetRetransmissionAlarm(); // Stop detections in quiecense. - blackhole_detector_.StopDetection(); + blackhole_detector_.StopDetection(/*permanent=*/false); return old_send_algorithm; } diff --git a/gquiche/quic/core/quic_connection.h b/gquiche/quic/core/quic_connection.h index 4cf7dea0..7b3e0ce2 100644 --- a/gquiche/quic/core/quic_connection.h +++ b/gquiche/quic/core/quic_connection.h @@ -36,7 +36,7 @@ #include "gquiche/quic/core/quic_alarm.h" #include "gquiche/quic/core/quic_alarm_factory.h" #include "gquiche/quic/core/quic_blocked_writer_interface.h" -#include "gquiche/quic/core/quic_circular_deque.h" +#include "gquiche/quic/core/quic_connection_context.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_connection_id_manager.h" #include "gquiche/quic/core/quic_connection_stats.h" @@ -58,6 +58,7 @@ #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -234,11 +235,11 @@ class QUIC_EXPORT_PRIVATE QuicConnectionVisitorInterface { // Called by the server to validate |token| in received INITIAL packets. // Consider the client address gets validated (and therefore remove // amplification factor) once the |token| gets successfully validated. - virtual bool ValidateToken(absl::string_view token) const = 0; + virtual bool ValidateToken(absl::string_view token) = 0; // Called by the server to send another token. // Return false if the crypto stream fail to generate one. - virtual void MaybeSendAddressToken() = 0; + virtual bool MaybeSendAddressToken() = 0; // Whether the server address is known to the connection. virtual bool IsKnownServerAddress(const QuicSocketAddress& address) const = 0; @@ -262,10 +263,9 @@ class QUIC_EXPORT_PRIVATE QuicConnectionDebugVisitor const QuicFrames& /*nonretransmittable_frames*/, QuicTime /*sent_time*/) {} - // Called when a coalesced packet has been sent. + // Called when a coalesced packet is successfully serialized. virtual void OnCoalescedPacketSent( - const QuicCoalescedPacket& /*coalesced_packet*/, - size_t /*length*/) {} + const QuicCoalescedPacket& /*coalesced_packet*/, size_t /*length*/) {} // Called when a PING frame has been sent. virtual void OnPingSent() {} @@ -404,6 +404,18 @@ class QUIC_EXPORT_PRIVATE QuicConnectionDebugVisitor // Called when a MaxStreamsFrame has been parsed. virtual void OnMaxStreamsFrame(const QuicMaxStreamsFrame& /*frame*/) {} + // Called when all frames in packet have been parsed. + virtual void OnPacketComplete() {} + + // Called when ack_delay_time in ack frame has been parsed. + virtual void OnAckFrameStart(QuicTime::Delta ack_delay_time) {} + + // Called when ack_range in ack frame has been parsed. + virtual void OnAckRange(QuicPacketNumber start, QuicPacketNumber end) {} + + // Done processing ack frame. + virtual void OnAckFrameEnd(QuicPacketNumber start) {} + // Called when an AckFrequencyFrame has been parsed. virtual void OnAckFrequencyFrame(const QuicAckFrequencyFrame& /*frame*/) {} @@ -777,9 +789,14 @@ class QUIC_EXPORT_PRIVATE QuicConnection const QuicSocketAddress& effective_peer_address() const { return default_path_.peer_address; } - const QuicConnectionId& connection_id() const { return ServerConnectionId(); } + + // Returns the server connection ID used on the default path. + const QuicConnectionId& connection_id() const { + return default_path_.server_connection_id; + } + const QuicConnectionId& client_connection_id() const { - return ClientConnectionId(); + return default_path_.client_connection_id; } void set_client_connection_id(QuicConnectionId client_connection_id); const QuicClock* clock() const { return clock_; } @@ -985,7 +1002,7 @@ class QUIC_EXPORT_PRIVATE QuicConnection // If |flush| is false, this will return a MESSAGE_STATUS_BLOCKED // when the connection is deemed unwritable. virtual MessageStatus SendMessage(QuicMessageId message_id, - QuicMemSliceSpan message, + absl::Span message, bool flush); // Returns the largest payload that will fit into a single MESSAGE frame. @@ -1028,7 +1045,7 @@ class QUIC_EXPORT_PRIVATE QuicConnection } const QuicSocketAddress& last_packet_source_address() const { - return last_packet_source_address_; + return last_received_packet_info_.source_address; } bool fill_up_link_during_probing() const { @@ -1146,7 +1163,7 @@ class QUIC_EXPORT_PRIVATE QuicConnection // Returns true if ack_alarm_ is set. bool HasPendingAcks() const; - void OnUserAgentIdKnown() { sent_packet_manager_.OnUserAgentIdKnown(); } + virtual void OnUserAgentIdKnown(const std::string& user_agent_id); // Enables Legacy Version Encapsulation using |server_name| as SNI. // Can only be set if this is a client connection. @@ -1199,12 +1216,6 @@ class QUIC_EXPORT_PRIVATE QuicConnection bool is_processing_packet() const { return framer_.is_processing_packet(); } - bool encrypted_control_frames() const { return encrypted_control_frames_; } - - bool use_encryption_level_context() const { - return use_encryption_level_context_; - } - bool HasPendingPathValidation() const; QuicPathValidationContext* GetPathValidationContext() const; @@ -1225,32 +1236,47 @@ class QUIC_EXPORT_PRIVATE QuicConnection void SetSourceAddressTokenToSend(absl::string_view token); void SendPing() { - SendPingAtLevel(use_encryption_level_context_ - ? framer().GetEncryptionLevelToSendApplicationData() - : encryption_level_); + SendPingAtLevel(framer().GetEncryptionLevelToSendApplicationData()); } + // Returns one server connection ID that associates the current session in the + // session map. + virtual QuicConnectionId GetOneActiveServerConnectionId() const; + + // Returns all server connection IDs that have not been removed from the + // session map. virtual std::vector GetActiveServerConnectionIds() const; bool validate_client_address() const { return validate_client_addresses_; } - bool support_multiple_connection_ids() const { - return support_multiple_connection_ids_; - } - - bool use_connection_id_on_default_path() const { - return use_connection_id_on_default_path_; - } - bool connection_migration_use_new_cid() const { return connection_migration_use_new_cid_; } + bool count_bytes_on_alternative_path_separately() const { + return count_bytes_on_alternative_path_separately_; + } + // Instantiates connection ID manager. void CreateConnectionIdManager(); - bool donot_write_mid_packet_processing() const { - return donot_write_mid_packet_processing_; + QuicConnectionContext* context() { return &context_; } + const QuicConnectionContext* context() const { return &context_; } + + void set_tracer(std::unique_ptr tracer) { + context_.tracer.swap(tracer); + } + + void set_bug_listener(std::unique_ptr bug_listener) { + context_.bug_listener.swap(bug_listener); + } + + absl::optional quic_bug_10511_43_timestamp() const { + return quic_bug_10511_43_timestamp_; + } + + const std::string& quic_bug_10511_43_error_detail() const { + return quic_bug_10511_43_error_detail_; } protected: @@ -1340,6 +1366,10 @@ class QUIC_EXPORT_PRIVATE QuicConnection default_path_.bytes_received_before_address_validation += length; } + void set_validate_client_addresses(bool value) { + validate_client_addresses_ = value; + } + private: friend class test::QuicConnectionPeer; @@ -1355,14 +1385,12 @@ class QUIC_EXPORT_PRIVATE QuicConnection const QuicSocketAddress& alternative_peer_address, const QuicConnectionId& client_connection_id, const QuicConnectionId& server_connection_id, - bool stateless_reset_token_received, - StatelessResetToken stateless_reset_token) + absl::optional stateless_reset_token) : self_address(alternative_self_address), peer_address(alternative_peer_address), client_connection_id(client_connection_id), server_connection_id(server_connection_id), - stateless_reset_token(stateless_reset_token), - stateless_reset_token_received(stateless_reset_token_received) {} + stateless_reset_token(stateless_reset_token) {} PathState(PathState&& other); @@ -1376,8 +1404,7 @@ class QUIC_EXPORT_PRIVATE QuicConnection QuicSocketAddress peer_address; QuicConnectionId client_connection_id; QuicConnectionId server_connection_id; - StatelessResetToken stateless_reset_token; - bool stateless_reset_token_received = false; + absl::optional stateless_reset_token; // True if the peer address has been validated. Address is considered // validated when 1) an address token of the peer address is received and // validated, or 2) a HANDSHAKE packet has been successfully processed on @@ -1419,15 +1446,38 @@ class QUIC_EXPORT_PRIVATE QuicConnection const QuicSocketAddress peer_address; }; - // UndecrytablePacket comprises a undecryptable packet and the its encryption - // level. + // ReceivedPacketInfo comprises the received packet information, which can be + // retrieved before the packet gets successfully decrypted. + struct QUIC_EXPORT_PRIVATE ReceivedPacketInfo { + explicit ReceivedPacketInfo(QuicTime receipt_time) + : received_bytes_counted(false), receipt_time(receipt_time) {} + ReceivedPacketInfo(const QuicSocketAddress& destination_address, + const QuicSocketAddress& source_address, + QuicTime receipt_time) + : received_bytes_counted(false), + destination_address(destination_address), + source_address(source_address), + receipt_time(receipt_time) {} + + bool received_bytes_counted; + QuicSocketAddress destination_address; + QuicSocketAddress source_address; + QuicTime receipt_time; + }; + + // UndecrytablePacket comprises a undecryptable packet and related + // information. struct QUIC_EXPORT_PRIVATE UndecryptablePacket { UndecryptablePacket(const QuicEncryptedPacket& packet, - EncryptionLevel encryption_level) - : packet(packet.Clone()), encryption_level(encryption_level) {} + EncryptionLevel encryption_level, + const ReceivedPacketInfo& packet_info) + : packet(packet.Clone()), + encryption_level(encryption_level), + packet_info(packet_info) {} std::unique_ptr packet; EncryptionLevel encryption_level; + ReceivedPacketInfo packet_info; }; // Handles the reverse path validation result depending on connection state: @@ -1449,6 +1499,11 @@ class QUIC_EXPORT_PRIVATE QuicConnection private: QuicConnection* connection_; QuicSocketAddress original_direct_peer_address_; + // TODO(b/205023946) Debug-only fields, to be deprecated after the bug is + // fixed. + QuicSocketAddress peer_address_default_path_; + QuicSocketAddress peer_address_alternative_path_; + AddressChangeType active_effective_peer_migration_type_; }; // A class which sets and clears in_on_retransmission_time_out_ when entering @@ -1464,27 +1519,9 @@ class QUIC_EXPORT_PRIVATE QuicConnection QuicConnection* connection_; // Not owned. }; - QuicConnectionId& ClientConnectionId() { - return use_connection_id_on_default_path_ - ? default_path_.client_connection_id - : client_connection_id_; - } - const QuicConnectionId& ClientConnectionId() const { - return use_connection_id_on_default_path_ - ? default_path_.client_connection_id - : client_connection_id_; - } - QuicConnectionId& ServerConnectionId() { - return use_connection_id_on_default_path_ - ? default_path_.server_connection_id - : server_connection_id_; - } - const QuicConnectionId& ServerConnectionId() const { - return use_connection_id_on_default_path_ - ? default_path_.server_connection_id - : server_connection_id_; - } - void SetServerConnectionId(const QuicConnectionId& server_connection_id); + // If peer uses non-empty connection ID, discards any buffered packets on path + // change in IETF QUIC. + void MaybeClearQueuedPacketsOnPathChange(); // Notifies the visitor of the close and marks the connection as disconnected. // Does not send a connection close frame to the peer. It should only be @@ -1506,14 +1543,28 @@ class QUIC_EXPORT_PRIVATE QuicConnection const QuicConnectionId& new_server_connection_id); // Given the server_connection_id find if there is already a corresponding - // client connection ID used on default/alternative path. - void FindMatchingClientConnectionIdOrToken( - const PathState& default_path, - const PathState& alternative_path, + // client connection ID used on default/alternative path. If not, find if + // there is an unused connection ID. + void FindMatchingOrNewClientConnectionIdOrToken( + const PathState& default_path, const PathState& alternative_path, const QuicConnectionId& server_connection_id, QuicConnectionId* client_connection_id, - bool* stateless_reset_token_received, - StatelessResetToken* stateless_reset_token) const; + absl::optional* stateless_reset_token); + + // Returns true and sets connection IDs if (self_address, peer_address) + // corresponds to either the default path or alternative path. Returns false + // otherwise. + bool FindOnPathConnectionIds(const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, + QuicConnectionId* client_connection_id, + QuicConnectionId* server_connection_id) const; + + // Set default_path_ to the new_path_state and update the connection IDs in + // packet creator accordingly. + void SetDefaultPathState(PathState new_path_state); + + // Returns true if header contains valid server connection ID. + bool ValidateServerConnectionId(const QuicPacketHeader& header) const; // Update the connection IDs when client migrates with/without validation. // Returns false if required connection ID is not available. @@ -1684,9 +1735,6 @@ class QUIC_EXPORT_PRIVATE QuicConnection // Returns the largest sent packet number that has been ACKed by peer. QuicPacketNumber GetLargestAckedPacket() const; - // Whether incoming_connection_ids_ contains connection_id. - bool HasIncomingConnectionId(QuicConnectionId connection_id); - // Whether connection is limited by amplification factor. bool LimitedByAmplificationFactor() const; @@ -1758,7 +1806,8 @@ class QUIC_EXPORT_PRIVATE QuicConnection // Send PATH_RESPONSE to the given peer address. bool SendPathResponse(const QuicPathFrameBuffer& data_buffer, - QuicSocketAddress peer_address_to_send); + const QuicSocketAddress& peer_address_to_send, + const QuicSocketAddress& effective_peer_address); // Update both connection's and packet creator's peer address. void UpdatePeerAddress(QuicSocketAddress peer_address); @@ -1825,6 +1874,13 @@ class QUIC_EXPORT_PRIVATE QuicConnection // when a new client connection ID is received. void OnClientConnectionIdAvailable(); + // Returns true if connection needs to set retransmission alarm after a packet + // gets sent. + bool ShouldSetRetransmissionAlarmOnPacketSent(bool in_flight, + EncryptionLevel level) const; + + QuicConnectionContext context_; + QuicFramer framer_; // Contents received in the current packet, especially used to identify @@ -1852,8 +1908,6 @@ class QUIC_EXPORT_PRIVATE QuicConnection const QuicClock* clock_; QuicRandom* random_generator_; - QuicConnectionId server_connection_id_; - QuicConnectionId client_connection_id_; // On the server, the connection ID is set when receiving the first packet. // This variable ensures we only set it this way once. bool client_connection_id_is_set_; @@ -1920,7 +1974,7 @@ class QUIC_EXPORT_PRIVATE QuicConnection // Collection of coalesced packets which were received while processing // the current packet. - QuicCircularDeque> + quiche::QuicheCircularDeque> received_coalesced_packets_; // Maximum number of undecryptable packets the connection will store. @@ -2001,10 +2055,9 @@ class QUIC_EXPORT_PRIVATE QuicConnection QuicPacketCreator packet_creator_; - // The time that a packet is received for this connection. Initialized to - // connection creation time. - // This does not indicate the packet was processed. - QuicTime time_of_last_received_packet_; + // Information about the last received QUIC packet, which may not have been + // successfully decrypted and processed. + ReceivedPacketInfo last_received_packet_info_; // Sent packet manager which tracks the status of packets sent by this // connection and contains the send and receive algorithms to determine when @@ -2022,13 +2075,10 @@ class QUIC_EXPORT_PRIVATE QuicConnection // close. bool connected_; - // Destination address of the last received packet. - QuicSocketAddress last_packet_destination_address_; - - // Source address of the last received packet. - QuicSocketAddress last_packet_source_address_; - - // Destination connection id of the last received packet. + // Destination connection ID of the last received packet. If this ID is the + // original server connection ID chosen by client and server replaces it with + // a different ID, last_packet_destination_connection_id_ is set to the + // replacement connection ID on the server side. QuicConnectionId last_packet_destination_connection_id_; // Set to false if the connection should not send truncated connection IDs to @@ -2094,12 +2144,6 @@ class QUIC_EXPORT_PRIVATE QuicConnection // retransmission code. bool probing_retransmission_pending_; - // Indicates whether a stateless reset token has been received from peer. - bool stateless_reset_token_received_; - // Stores received stateless reset token from peer. Used to verify whether a - // packet is a stateless reset packet. - StatelessResetToken received_stateless_reset_token_; - // Id of latest sent control frame. 0 if no control frame has been sent. QuicControlFrameId last_control_frame_id_; @@ -2130,18 +2174,15 @@ class QUIC_EXPORT_PRIVATE QuicConnection // saved and responded to. // TODO(danzh) deprecate this field when deprecating // --quic_send_path_response. - QuicCircularDeque received_path_challenge_payloads_; + quiche::QuicheCircularDeque + received_path_challenge_payloads_; // Buffer outstanding PATH_CHALLENGEs if socket write is blocked, future // OnCanWrite will attempt to respond with PATH_RESPONSEs using the retained // payload and peer addresses. // TODO(fayang): remove this when deprecating quic_drop_unsent_path_response. - QuicCircularDeque pending_path_challenge_payloads_; - - // Set of connection IDs that should be accepted as destination on - // received packets. This is conceptually a set but is implemented as a - // vector to improve performance since it is expected to be very small. - std::vector incoming_connection_ids_; + quiche::QuicheCircularDeque + pending_path_challenge_payloads_; // When we receive a RETRY packet or some INITIAL packets, we replace // |server_connection_id_| with the value from that packet and save off the @@ -2149,6 +2190,9 @@ class QUIC_EXPORT_PRIVATE QuicConnection // |original_destination_connection_id_| for validation. absl::optional original_destination_connection_id_; + // The connection ID that replaces original_destination_connection_id_. + QuicConnectionId original_destination_connection_id_replacement_; + // After we receive a RETRY packet, |retry_source_connection_id_| contains // the source connection ID from that packet. absl::optional retry_source_connection_id_; @@ -2189,13 +2233,8 @@ class QUIC_EXPORT_PRIVATE QuicConnection size_t anti_amplification_factor_ = GetQuicFlag(FLAGS_quic_anti_amplification_factor); - bool start_peer_migration_earlier_ = - GetQuicReloadableFlag(quic_start_peer_migration_earlier); - - // latch --gfe2_reloadable_flag_quic_send_path_response and - // --gfe2_reloadable_flag_quic_start_peer_migration_earlier. - bool send_path_response_ = start_peer_migration_earlier_ && - GetQuicReloadableFlag(quic_send_path_response2); + // latch --gfe2_reloadable_flag_quic_send_path_response. + bool send_path_response_ = GetQuicReloadableFlag(quic_send_path_response2); bool use_path_validator_ = send_path_response_ && @@ -2228,10 +2267,6 @@ class QUIC_EXPORT_PRIVATE QuicConnection // True if we are currently processing OnRetransmissionTimeout. bool in_on_retransmission_time_out_ = false; - const bool encrypted_control_frames_; - - const bool use_encryption_level_context_; - QuicPathValidator path_validator_; // Stores information of a path which maybe used as default path in the @@ -2249,22 +2284,12 @@ class QUIC_EXPORT_PRIVATE QuicConnection // This field is used to debug b/177312785. QuicFrameType most_recent_frame_type_; - bool current_incoming_packet_received_bytes_counted_ = false; - bool count_bytes_on_alternative_path_separately_ = GetQuicReloadableFlag(quic_count_bytes_on_alternative_path_seperately); // If true, upon seeing a new client address, validate the client address. bool validate_client_addresses_ = false; - bool support_multiple_connection_ids_ = false; - - const bool donot_write_mid_packet_processing_ = - GetQuicReloadableFlag(quic_donot_write_mid_packet_processing); - - bool use_connection_id_on_default_path_ = - GetQuicReloadableFlag(quic_use_connection_id_on_default_path); - // Indicates whether we should proactively validate peer address on a // PATH_CHALLENGE received. bool should_proactively_validate_peer_address_on_path_challenge_ = false; @@ -2272,12 +2297,14 @@ class QUIC_EXPORT_PRIVATE QuicConnection // Enable this via reloadable flag once this feature is complete. bool connection_migration_use_new_cid_ = false; - const bool group_path_response_and_challenge_sending_closer_ = + const bool reset_per_packet_state_for_undecryptable_packets_ = GetQuicReloadableFlag( - quic_group_path_response_and_challenge_sending_closer); + quic_reset_per_packet_state_for_undecryptable_packets); - const bool quic_deprecate_incoming_connection_ids_ = - GetQuicReloadableFlag(quic_deprecate_incoming_connection_ids); + // TODO(b/205023946) Debug-only fields, to be deprecated after the bug is + // fixed. + absl::optional quic_bug_10511_43_timestamp_; + std::string quic_bug_10511_43_error_detail_; }; } // namespace quic diff --git a/gquiche/quic/core/quic_connection_context.cc b/gquiche/quic/core/quic_connection_context.cc new file mode 100644 index 00000000..b222f0fd --- /dev/null +++ b/gquiche/quic/core/quic_connection_context.cc @@ -0,0 +1,36 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/quic_connection_context.h" + +#include "gquiche/common/platform/api/quiche_thread_local.h" + +namespace quic { +namespace { +DEFINE_QUICHE_THREAD_LOCAL_POINTER(CurrentContext, QuicConnectionContext); +} // namespace + +// static +QuicConnectionContext* QuicConnectionContext::Current() { + return GET_QUICHE_THREAD_LOCAL_POINTER(CurrentContext); +} + +QuicConnectionContextSwitcher::QuicConnectionContextSwitcher( + QuicConnectionContext* new_context) + : old_context_(QuicConnectionContext::Current()) { + SET_QUICHE_THREAD_LOCAL_POINTER(CurrentContext, new_context); + if (new_context && new_context->tracer) { + new_context->tracer->Activate(); + } +} + +QuicConnectionContextSwitcher::~QuicConnectionContextSwitcher() { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->Deactivate(); + } + SET_QUICHE_THREAD_LOCAL_POINTER(CurrentContext, old_context_); +} + +} // namespace quic diff --git a/gquiche/quic/core/quic_connection_context.h b/gquiche/quic/core/quic_connection_context.h new file mode 100644 index 00000000..42a79e9f --- /dev/null +++ b/gquiche/quic/core/quic_connection_context.h @@ -0,0 +1,133 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_CORE_QUIC_CONNECTION_CONTEXT_H_ +#define QUICHE_QUIC_CORE_QUIC_CONNECTION_CONTEXT_H_ + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/common/platform/api/quiche_logging.h" + +namespace quic { + +// QuicConnectionTracer is responsible for emit trace messages for a single +// QuicConnection. +// QuicConnectionTracer is part of the QuicConnectionContext. +class QUIC_EXPORT_PRIVATE QuicConnectionTracer { + public: + virtual ~QuicConnectionTracer() = default; + + // Emit a trace message from a string literal. The trace may simply remember + // the address of the literal in this function and read it at a later time. + virtual void PrintLiteral(const char* literal) = 0; + + // Emit a trace message from a string_view. Unlike PrintLiteral, this function + // will not read |s| after it returns. + virtual void PrintString(absl::string_view s) = 0; + + // Emit a trace message from printf-style arguments. + template + void Printf(const absl::FormatSpec& format, const Args&... args) { + std::string s = absl::StrFormat(format, args...); + PrintString(s); + } + + private: + friend class QuicConnectionContextSwitcher; + + // Called by QuicConnectionContextSwitcher, when |this| becomes the current + // thread's QUIC connection tracer. + // + // Activate/Deactivate are only called by QuicConnectionContextSwitcher's + // constructor/destructor, they always come in pairs. + virtual void Activate() {} + + // Called by QuicConnectionContextSwitcher, when |this| stops from being the + // current thread's QUIC connection tracer. + // + // Activate/Deactivate are only called by QuicConnectionContextSwitcher's + // constructor/destructor, they always come in pairs. + virtual void Deactivate() {} +}; + +// QuicBugListener is a helper class for implementing QUIC_BUG. The QUIC_BUG +// implementation can send the bug information into quic::CurrentBugListener(). +class QUIC_EXPORT_PRIVATE QuicBugListener { + public: + virtual ~QuicBugListener() = default; + virtual void OnQuicBug(const char* bug_id, const char* file, int line, + absl::string_view bug_message) = 0; +}; + +// QuicConnectionContext is a per-QuicConnection context that includes +// facilities useable by any part of a QuicConnection. A QuicConnectionContext +// is owned by a QuicConnection. +// +// The 'top-level' QuicConnection functions are responsible for maintaining the +// thread-local QuicConnectionContext pointer, such that any function called by +// them(directly or indirectly) can access the context. +// +// Like QuicConnection, all facilities in QuicConnectionContext are assumed to +// be called from a single thread at a time, they are NOT thread-safe. +struct QUIC_EXPORT_PRIVATE QuicConnectionContext final { + // Get the context on the current executing thread. nullptr if the current + // function is not called from a 'top-level' QuicConnection function. + static QuicConnectionContext* Current(); + + std::unique_ptr tracer; + std::unique_ptr bug_listener; +}; + +// QuicConnectionContextSwitcher is a RAII object used for maintaining the +// thread-local QuicConnectionContext pointer. +class QUIC_EXPORT_PRIVATE QuicConnectionContextSwitcher final { + public: + // The constructor switches from QuicConnectionContext::Current() to + // |new_context|. + explicit QuicConnectionContextSwitcher(QuicConnectionContext* new_context); + + // The destructor switches from QuicConnectionContext::Current() back to the + // old context. + ~QuicConnectionContextSwitcher(); + + private: + QuicConnectionContext* old_context_; +}; + +// Emit a trace message from a string literal to the current tracer(if any). +inline void QUIC_TRACELITERAL(const char* literal) { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->PrintLiteral(literal); + } +} + +// Emit a trace message from a string_view to the current tracer(if any). +inline void QUIC_TRACESTRING(absl::string_view s) { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->PrintString(s); + } +} + +// Emit a trace message from printf-style arguments to the current tracer(if +// any). +template +void QUIC_TRACEPRINTF(const absl::FormatSpec& format, + const Args&... args) { + QuicConnectionContext* current = QuicConnectionContext::Current(); + if (current && current->tracer) { + current->tracer->Printf(format, args...); + } +} + +inline QuicBugListener* CurrentBugListener() { + QuicConnectionContext* current = QuicConnectionContext::Current(); + return (current != nullptr) ? current->bug_listener.get() : nullptr; +} + +} // namespace quic + +#endif // QUICHE_QUIC_CORE_QUIC_CONNECTION_CONTEXT_H_ diff --git a/gquiche/quic/core/quic_connection_context_test.cc b/gquiche/quic/core/quic_connection_context_test.cc new file mode 100644 index 00000000..86c2ed2b --- /dev/null +++ b/gquiche/quic/core/quic_connection_context_test.cc @@ -0,0 +1,173 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/core/quic_connection_context.h" + +#include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/quic/platform/api/quic_thread.h" + +using testing::ElementsAre; + +namespace quic { +namespace { + +class TraceCollector : public QuicConnectionTracer { + public: + ~TraceCollector() override = default; + + void PrintLiteral(const char* literal) override { trace_.push_back(literal); } + + void PrintString(absl::string_view s) override { + trace_.push_back(std::string(s)); + } + + const std::vector& trace() const { return trace_; } + + private: + std::vector trace_; +}; + +struct FakeConnection { + FakeConnection() { context.tracer = std::make_unique(); } + + const std::vector& trace() const { + return static_cast(context.tracer.get())->trace(); + } + + QuicConnectionContext context; +}; + +void SimpleSwitch() { + FakeConnection connection; + + // These should be ignored since current context is nullptr. + EXPECT_EQ(QuicConnectionContext::Current(), nullptr); + QUIC_TRACELITERAL("before switch: literal"); + QUIC_TRACESTRING(std::string("before switch: string")); + QUIC_TRACEPRINTF("%s: %s", "before switch", "printf"); + + { + QuicConnectionContextSwitcher switcher(&connection.context); + QUIC_TRACELITERAL("literal"); + QUIC_TRACESTRING(std::string("string")); + QUIC_TRACEPRINTF("%s", "printf"); + } + + EXPECT_EQ(QuicConnectionContext::Current(), nullptr); + QUIC_TRACELITERAL("after switch: literal"); + QUIC_TRACESTRING(std::string("after switch: string")); + QUIC_TRACEPRINTF("%s: %s", "after switch", "printf"); + + EXPECT_THAT(connection.trace(), ElementsAre("literal", "string", "printf")); +} + +void NestedSwitch() { + FakeConnection outer, inner; + + { + QuicConnectionContextSwitcher switcher(&outer.context); + QUIC_TRACELITERAL("outer literal 0"); + QUIC_TRACESTRING(std::string("outer string 0")); + QUIC_TRACEPRINTF("%s %s %d", "outer", "printf", 0); + + { + QuicConnectionContextSwitcher switcher(&inner.context); + QUIC_TRACELITERAL("inner literal"); + QUIC_TRACESTRING(std::string("inner string")); + QUIC_TRACEPRINTF("%s %s", "inner", "printf"); + } + + QUIC_TRACELITERAL("outer literal 1"); + QUIC_TRACESTRING(std::string("outer string 1")); + QUIC_TRACEPRINTF("%s %s %d", "outer", "printf", 1); + } + + EXPECT_THAT(outer.trace(), ElementsAre("outer literal 0", "outer string 0", + "outer printf 0", "outer literal 1", + "outer string 1", "outer printf 1")); + + EXPECT_THAT(inner.trace(), + ElementsAre("inner literal", "inner string", "inner printf")); +} + +void AlternatingSwitch() { + FakeConnection zero, one, two; + for (int i = 0; i < 15; ++i) { + FakeConnection* connection = + ((i % 3) == 0) ? &zero : (((i % 3) == 1) ? &one : &two); + + QuicConnectionContextSwitcher switcher(&connection->context); + QUIC_TRACEPRINTF("%d", i); + } + + EXPECT_THAT(zero.trace(), ElementsAre("0", "3", "6", "9", "12")); + EXPECT_THAT(one.trace(), ElementsAre("1", "4", "7", "10", "13")); + EXPECT_THAT(two.trace(), ElementsAre("2", "5", "8", "11", "14")); +} + +typedef void (*ThreadFunction)(); + +template +class TestThread : public QuicThread { + public: + TestThread() : QuicThread("TestThread") {} + ~TestThread() override = default; + + protected: + void Run() override { func(); } +}; + +template +void RunInThreads(size_t n_threads) { + using ThreadType = TestThread; + std::vector threads(n_threads); + + for (ThreadType& t : threads) { + t.Start(); + } + + for (ThreadType& t : threads) { + t.Join(); + } +} + +class QuicConnectionContextTest : public QuicTest { + protected: +}; + +TEST_F(QuicConnectionContextTest, NullTracerOK) { + FakeConnection connection; + std::unique_ptr tracer; + + { + QuicConnectionContextSwitcher switcher(&connection.context); + QUIC_TRACELITERAL("msg 1 recorded"); + } + + connection.context.tracer.swap(tracer); + + { + QuicConnectionContextSwitcher switcher(&connection.context); + // Should be a no-op since connection.context.tracer is nullptr. + QUIC_TRACELITERAL("msg 2 ignored"); + } + + EXPECT_THAT(static_cast(tracer.get())->trace(), + ElementsAre("msg 1 recorded")); +} + +TEST_F(QuicConnectionContextTest, TestSimpleSwitch) { + RunInThreads(10); +} + +TEST_F(QuicConnectionContextTest, TestNestedSwitch) { + RunInThreads(10); +} + +TEST_F(QuicConnectionContextTest, TestAlternatingSwitch) { + RunInThreads(10); +} + +} // namespace +} // namespace quic diff --git a/gquiche/quic/core/quic_connection_id.cc b/gquiche/quic/core/quic_connection_id.cc index 6fd24212..94879e9c 100644 --- a/gquiche/quic/core/quic_connection_id.cc +++ b/gquiche/quic/core/quic_connection_id.cc @@ -18,7 +18,6 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/quiche_endian.h" namespace quic { diff --git a/gquiche/quic/core/quic_connection_id_manager.cc b/gquiche/quic/core/quic_connection_id_manager.cc index c4f64db6..ed0dcb39 100644 --- a/gquiche/quic/core/quic_connection_id_manager.cc +++ b/gquiche/quic/core/quic_connection_id_manager.cc @@ -3,12 +3,14 @@ // found in the LICENSE file. #include "gquiche/quic/core/quic_connection_id_manager.h" + #include #include "gquiche/quic/core/quic_clock.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_utils.h" +#include "gquiche/common/platform/api/quiche_logging.h" namespace quic { @@ -22,11 +24,13 @@ QuicConnectionIdData::QuicConnectionIdData( namespace { -class RetirePeerIssuedConnectionIdAlarm : public QuicAlarm::Delegate { +class RetirePeerIssuedConnectionIdAlarm + : public QuicAlarm::DelegateWithContext { public: explicit RetirePeerIssuedConnectionIdAlarm( - QuicConnectionIdManagerVisitorInterface* visitor) - : visitor_(visitor) {} + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), visitor_(visitor) {} RetirePeerIssuedConnectionIdAlarm(const RetirePeerIssuedConnectionIdAlarm&) = delete; RetirePeerIssuedConnectionIdAlarm& operator=( @@ -61,13 +65,13 @@ std::vector::iterator FindConnectionIdData( QuicPeerIssuedConnectionIdManager::QuicPeerIssuedConnectionIdManager( size_t active_connection_id_limit, const QuicConnectionId& initial_peer_issued_connection_id, - const QuicClock* clock, - QuicAlarmFactory* alarm_factory, - QuicConnectionIdManagerVisitorInterface* visitor) + const QuicClock* clock, QuicAlarmFactory* alarm_factory, + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context) : active_connection_id_limit_(active_connection_id_limit), clock_(clock), retire_connection_id_alarm_(alarm_factory->CreateAlarm( - new RetirePeerIssuedConnectionIdAlarm(visitor))) { + new RetirePeerIssuedConnectionIdAlarm(visitor, context))) { QUICHE_DCHECK_GE(active_connection_id_limit_, 2u); QUICHE_DCHECK(!initial_peer_issued_connection_id.IsEmpty()); active_connection_id_data_.emplace_backCreateAlarm( - new RetireSelfIssuedConnectionIdAlarmDelegate(this))), + new RetireSelfIssuedConnectionIdAlarmDelegate(this, context))), last_connection_id_(initial_connection_id), next_connection_id_sequence_number_(1u), last_connection_id_consumed_by_self_sequence_number_(0u) { @@ -372,6 +379,12 @@ QuicSelfIssuedConnectionIdManager::GetUnretiredConnectionIds() const { return unretired_ids; } +QuicConnectionId QuicSelfIssuedConnectionIdManager::GetOneActiveConnectionId() + const { + QUICHE_DCHECK(!active_connection_ids_.empty()); + return active_connection_ids_.front().first; +} + void QuicSelfIssuedConnectionIdManager::RetireConnectionId() { if (to_be_retired_connection_ids_.empty()) { QUIC_BUG(quic_bug_12420_1) diff --git a/gquiche/quic/core/quic_connection_id_manager.h b/gquiche/quic/core/quic_connection_id_manager.h index 90b5fcea..9926e12a 100644 --- a/gquiche/quic/core/quic_connection_id_manager.h +++ b/gquiche/quic/core/quic_connection_id_manager.h @@ -60,9 +60,9 @@ class QUIC_EXPORT_PRIVATE QuicPeerIssuedConnectionIdManager { QuicPeerIssuedConnectionIdManager( size_t active_connection_id_limit, const QuicConnectionId& initial_peer_issued_connection_id, - const QuicClock* clock, - QuicAlarmFactory* alarm_factory, - QuicConnectionIdManagerVisitorInterface* visitor); + const QuicClock* clock, QuicAlarmFactory* alarm_factory, + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context); ~QuicPeerIssuedConnectionIdManager(); @@ -123,10 +123,10 @@ class QUIC_EXPORT_PRIVATE QuicSelfIssuedConnectionIdManager { public: QuicSelfIssuedConnectionIdManager( size_t active_connection_id_limit, - const QuicConnectionId& initial_connection_id, - const QuicClock* clock, + const QuicConnectionId& initial_connection_id, const QuicClock* clock, QuicAlarmFactory* alarm_factory, - QuicConnectionIdManagerVisitorInterface* visitor); + QuicConnectionIdManagerVisitorInterface* visitor, + QuicConnectionContext* context); virtual ~QuicSelfIssuedConnectionIdManager(); @@ -139,6 +139,8 @@ class QUIC_EXPORT_PRIVATE QuicSelfIssuedConnectionIdManager { std::vector GetUnretiredConnectionIds() const; + QuicConnectionId GetOneActiveConnectionId() const; + // Called when the retire_connection_id alarm_ fires. Removes the to be // retired connection ID locally. void RetireConnectionId(); diff --git a/gquiche/quic/core/quic_connection_id_manager_test.cc b/gquiche/quic/core/quic_connection_id_manager_test.cc index a64f2282..e454a518 100644 --- a/gquiche/quic/core/quic_connection_id_manager_test.cc +++ b/gquiche/quic/core/quic_connection_id_manager_test.cc @@ -78,11 +78,9 @@ class TestPeerIssuedConnectionIdManagerVisitor class QuicPeerIssuedConnectionIdManagerTest : public QuicTest { public: QuicPeerIssuedConnectionIdManagerTest() - : peer_issued_cid_manager_(/*active_connection_id_limit=*/2, - initial_connection_id_, - &clock_, - &alarm_factory_, - &cid_manager_visitor_) { + : peer_issued_cid_manager_( + /*active_connection_id_limit=*/2, initial_connection_id_, &clock_, + &alarm_factory_, &cid_manager_visitor_, /*context=*/nullptr) { clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); cid_manager_visitor_.SetPeerIssuedConnectionIdManager( &peer_issued_cid_manager_); @@ -538,11 +536,9 @@ class TestSelfIssuedConnectionIdManagerVisitor class QuicSelfIssuedConnectionIdManagerTest : public QuicTest { public: QuicSelfIssuedConnectionIdManagerTest() - : cid_manager_(/*active_connection_id_limit*/ 2, - initial_connection_id_, - &clock_, - &alarm_factory_, - &cid_manager_visitor_) { + : cid_manager_(/*active_connection_id_limit*/ 2, initial_connection_id_, + &clock_, &alarm_factory_, &cid_manager_visitor_, + /*context=*/nullptr) { clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); retire_self_issued_cid_alarm_ = QuicConnectionIdManagerPeer::GetRetireSelfIssuedConnectionIdAlarm( diff --git a/gquiche/quic/core/quic_connection_stats.cc b/gquiche/quic/core/quic_connection_stats.cc index d49c20df..b8d87637 100644 --- a/gquiche/quic/core/quic_connection_stats.cc +++ b/gquiche/quic/core/quic_connection_stats.cc @@ -34,8 +34,9 @@ std::ostream& operator<<(std::ostream& os, const QuicConnectionStats& s) { os << " pto_count: " << s.pto_count; os << " min_rtt_us: " << s.min_rtt_us; os << " srtt_us: " << s.srtt_us; - os << " max_packet_size: " << s.max_packet_size; - os << " max_received_packet_size: " << s.max_received_packet_size; + os << " egress_mtu: " << s.egress_mtu; + os << " max_egress_mtu: " << s.max_egress_mtu; + os << " ingress_mtu: " << s.ingress_mtu; os << " estimated_bandwidth: " << s.estimated_bandwidth; os << " packets_reordered: " << s.packets_reordered; os << " max_sequence_reordering: " << s.max_sequence_reordering; diff --git a/gquiche/quic/core/quic_connection_stats.h b/gquiche/quic/core/quic_connection_stats.h index 01979cf0..d139bd41 100644 --- a/gquiche/quic/core/quic_connection_stats.h +++ b/gquiche/quic/core/quic_connection_stats.h @@ -103,8 +103,14 @@ struct QUIC_EXPORT_PRIVATE QuicConnectionStats { int64_t min_rtt_us = 0; // Minimum RTT in microseconds. int64_t srtt_us = 0; // Smoothed RTT in microseconds. int64_t cwnd_bootstrapping_rtt_us = 0; // RTT used in cwnd_bootstrapping. - QuicByteCount max_packet_size = 0; - QuicByteCount max_received_packet_size = 0; + // The connection's |long_term_mtu_| used for sending packets, populated by + // QuicConnection::GetStats(). + QuicByteCount egress_mtu = 0; + // The maximum |long_term_mtu_| the connection ever used. + QuicByteCount max_egress_mtu = 0; + // Size of the largest packet received from the peer, populated by + // QuicConnection::GetStats(). + QuicByteCount ingress_mtu = 0; QuicBandwidth estimated_bandwidth = QuicBandwidth::Zero(); // Reordering stats for received packets. @@ -209,6 +215,10 @@ struct QUIC_EXPORT_PRIVATE QuicConnectionStats { // which was canceled because the peer migrated again. Such migration is also // counted as invalid peer migration. size_t num_peer_migration_while_validating_default_path = 0; + // Number of NEW_CONNECTION_ID frames sent. + size_t num_new_connection_id_sent = 0; + // Number of RETIRE_CONNECTION_ID frames sent. + size_t num_retire_connection_id_sent = 0; struct QUIC_NO_EXPORT TlsServerOperationStats { bool success = false; diff --git a/gquiche/quic/core/quic_connection_test.cc b/gquiche/quic/core/quic_connection_test.cc index 25e60632..25111411 100644 --- a/gquiche/quic/core/quic_connection_test.cc +++ b/gquiche/quic/core/quic_connection_test.cc @@ -23,6 +23,7 @@ #include "gquiche/quic/core/frames/quic_connection_close_frame.h" #include "gquiche/quic/core/frames/quic_new_connection_id_frame.h" #include "gquiche/quic/core/frames/quic_path_response_frame.h" +#include "gquiche/quic/core/frames/quic_rst_stream_frame.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/core/quic_error_codes.h" @@ -33,6 +34,7 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" +#include "gquiche/quic/platform/api/quic_error_code_wrappers.h" #include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_ip_address.h" @@ -177,33 +179,6 @@ class TestConnectionHelper : public QuicConnectionHelperInterface { SimpleBufferAllocator buffer_allocator_; }; -class TestAlarmFactory : public QuicAlarmFactory { - public: - class TestAlarm : public QuicAlarm { - public: - explicit TestAlarm(QuicArenaScopedPtr delegate) - : QuicAlarm(std::move(delegate)) {} - - void SetImpl() override {} - void CancelImpl() override {} - using QuicAlarm::Fire; - }; - - TestAlarmFactory() {} - TestAlarmFactory(const TestAlarmFactory&) = delete; - TestAlarmFactory& operator=(const TestAlarmFactory&) = delete; - - QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override { - return new TestAlarm(QuicArenaScopedPtr(delegate)); - } - - QuicArenaScopedPtr CreateAlarm( - QuicArenaScopedPtr delegate, - QuicConnectionArena* arena) override { - return arena->New(std::move(delegate)); - } -}; - class TestConnection : public QuicConnection { public: TestConnection(QuicConnectionId connection_id, @@ -537,6 +512,15 @@ class TestConnection : public QuicConnection { return false; } + void SendOrQueuePacket(SerializedPacket packet) override { + QuicConnection::SendOrQueuePacket(std::move(packet)); + self_address_on_default_path_while_sending_packet_ = self_address(); + } + + QuicSocketAddress self_address_on_default_path_while_sending_packet() { + return self_address_on_default_path_while_sending_packet_; + } + SimpleDataProducer* producer() { return &producer_; } using QuicConnection::active_effective_peer_migration_type; @@ -563,6 +547,8 @@ class TestConnection : public QuicConnection { SimpleSessionNotifier* notifier_; std::unique_ptr next_effective_peer_addr_; + + QuicSocketAddress self_address_on_default_path_while_sending_packet_; }; enum class AckResponse { kDefer, kImmediate }; @@ -777,7 +763,6 @@ class QuicConnectionTest : public QuicTestWithParam { std::unique_ptr decrypter) { if (connection_.version().KnowsWhichDecrypterToUse()) { connection_.InstallDecrypter(level, std::move(decrypter)); - connection_.RemoveDecrypter(ENCRYPTION_INITIAL); } else { connection_.SetDecrypter(level, std::move(decrypter)); } @@ -824,10 +809,10 @@ class QuicConnectionTest : public QuicTestWithParam { level); } - void ProcessFramesPacketWithAddresses(QuicFrames frames, - QuicSocketAddress self_address, - QuicSocketAddress peer_address, - EncryptionLevel level) { + std::unique_ptr ConstructPacket(QuicFrames frames, + EncryptionLevel level, + char* buffer, + size_t buffer_len) { QUICHE_DCHECK(peer_framer_.HasEncrypterOfEncryptionLevel(level)); peer_creator_.set_encryption_level(level); QuicPacketCreatorPeer::SetSendVersionInPacket( @@ -835,14 +820,23 @@ class QuicConnectionTest : public QuicTestWithParam { level < ENCRYPTION_FORWARD_SECURE && connection_.perspective() == Perspective::IS_SERVER); - char buffer[kMaxOutgoingPacketSize]; SerializedPacket serialized_packet = - QuicPacketCreatorPeer::SerializeAllFrames( - &peer_creator_, frames, buffer, kMaxOutgoingPacketSize); + QuicPacketCreatorPeer::SerializeAllFrames(&peer_creator_, frames, + buffer, buffer_len); + return std::make_unique( + serialized_packet.encrypted_buffer, serialized_packet.encrypted_length, + clock_.Now()); + } + + void ProcessFramesPacketWithAddresses(QuicFrames frames, + QuicSocketAddress self_address, + QuicSocketAddress peer_address, + EncryptionLevel level) { + char buffer[kMaxOutgoingPacketSize]; connection_.ProcessUdpPacket( self_address, peer_address, - QuicReceivedPacket(serialized_packet.encrypted_buffer, - serialized_packet.encrypted_length, clock_.Now())); + *ConstructPacket(std::move(frames), level, buffer, + kMaxOutgoingPacketSize)); if (connection_.GetSendAlarm()->IsSet()) { connection_.GetSendAlarm()->Fire(); } @@ -903,7 +897,6 @@ class QuicConnectionTest : public QuicTestWithParam { connection_.InstallDecrypter( QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), std::make_unique(0x01)); - connection_.RemoveDecrypter(ENCRYPTION_INITIAL); } else { connection_.SetDecrypter( QuicPacketCreatorPeer::GetEncryptionLevel(&peer_creator_), @@ -1102,12 +1095,9 @@ class QuicConnectionTest : public QuicTestWithParam { MessageStatus SendMessage(absl::string_view message) { connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); - return connection_.SendMessage( - 1, - MakeSpan(connection_.helper()->GetStreamSendBufferAllocator(), message, - &storage), - false); + QuicMemSlice slice(QuicBuffer::Copy( + connection_.helper()->GetStreamSendBufferAllocator(), message)); + return connection_.SendMessage(1, absl::MakeSpan(&slice, 1), false); } void ProcessAckPacket(uint64_t packet_number, QuicAckFrame* frame) { @@ -1179,7 +1169,12 @@ class QuicConnectionTest : public QuicTestWithParam { } if (peer_framer_.version().HasIetfInvariantHeader() && peer_framer_.perspective() == Perspective::IS_SERVER) { - header.destination_connection_id_included = CONNECTION_ID_ABSENT; + if (!connection_.client_connection_id().IsEmpty()) { + header.destination_connection_id = connection_.client_connection_id(); + header.destination_connection_id_included = CONNECTION_ID_PRESENT; + } else { + header.destination_connection_id_included = CONNECTION_ID_ABSENT; + } if (header.version_flag) { header.source_connection_id = connection_id_; header.source_connection_id_included = CONNECTION_ID_PRESENT; @@ -1199,8 +1194,7 @@ class QuicConnectionTest : public QuicTestWithParam { EncryptionLevel level) { QuicPacketHeader header = ConstructPacketHeader(number, level); QuicFrames frames; - if (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) && - VersionHasIetfQuicFrames(version().transport_version) && + if (VersionHasIetfQuicFrames(version().transport_version) && (level == ENCRYPTION_INITIAL || level == ENCRYPTION_HANDSHAKE)) { frames.push_back(QuicFrame(QuicPingFrame())); frames.push_back(QuicFrame(QuicPaddingFrame(100))); @@ -1336,9 +1330,12 @@ class QuicConnectionTest : public QuicTestWithParam { connection_.set_perspective(perspective); if (perspective == Perspective::IS_SERVER) { QuicConfig config; - QuicTagVector connection_options; - connection_options.push_back(kRVCM); - config.SetInitialReceivedConnectionOptions(connection_options); + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicTagVector connection_options; + connection_options.push_back(kRVCM); + config.SetInitialReceivedConnectionOptions(connection_options); + } EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); connection_.SetFromConfig(config); @@ -1429,7 +1426,8 @@ class QuicConnectionTest : public QuicTestWithParam { EXPECT_TRUE(connection_.connected()); } - void PathProbeTestInit(Perspective perspective) { + void PathProbeTestInit(Perspective perspective, + bool receive_new_server_connection_id = true) { set_perspective(perspective); connection_.CreateConnectionIdManager(); EXPECT_EQ(connection_.perspective(), perspective); @@ -1438,6 +1436,9 @@ class QuicConnectionTest : public QuicTestWithParam { } connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); peer_creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); // Prevent packets from being coalesced. EXPECT_CALL(visitor_, GetHandshakeState()) .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); @@ -1463,6 +1464,17 @@ class QuicConnectionTest : public QuicTestWithParam { kPeerAddress, ENCRYPTION_FORWARD_SECURE); EXPECT_EQ(kPeerAddress, connection_.peer_address()); EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + if (perspective == Perspective::IS_CLIENT && + receive_new_server_connection_id && version().HasIetfQuicFrames()) { + QuicNewConnectionIdFrame frame; + frame.connection_id = TestConnectionId(1234); + ASSERT_NE(frame.connection_id, connection_.connection_id()); + frame.stateless_reset_token = + QuicUtils::GenerateStatelessResetToken(frame.connection_id); + frame.retire_prior_to = 0u; + frame.sequence_number = 1u; + connection_.OnNewConnectionIdFrame(frame); + } } void TestClientRetryHandling(bool invalid_retry_tag, @@ -1506,8 +1518,7 @@ class QuicConnectionTest : public QuicTestWithParam { }; // Run all end to end tests with all supported versions. -INSTANTIATE_TEST_SUITE_P(QuicConnectionTests, - QuicConnectionTest, +INSTANTIATE_TEST_SUITE_P(QuicConnectionTests, QuicConnectionTest, ::testing::ValuesIn(GetTestParams()), ::testing::PrintToStringParamName()); @@ -1701,13 +1712,8 @@ TEST_P(QuicConnectionTest, PeerPortChangeAtServer) { EXPECT_CALL(visitor_, OnStreamFrame(_)) .WillOnce(Invoke( [=]() { EXPECT_EQ(kPeerAddress, connection_.peer_address()); })) - .WillOnce(Invoke([=]() { - EXPECT_EQ((GetQuicReloadableFlag(quic_start_peer_migration_earlier) || - !GetParam().version.HasIetfQuicFrames() - ? kNewPeerAddress - : kPeerAddress), - connection_.peer_address()); - })); + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); })); QuicFrames frames; frames.push_back(QuicFrame(frame1_)); ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, @@ -1743,6 +1749,9 @@ TEST_P(QuicConnectionTest, PeerIpAddressChangeAtServer) { QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); // Prevent packets from being coalesced. EXPECT_CALL(visitor_, GetHandshakeState()) .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); @@ -1775,13 +1784,8 @@ TEST_P(QuicConnectionTest, PeerIpAddressChangeAtServer) { EXPECT_CALL(visitor_, OnStreamFrame(_)) .WillOnce(Invoke( [=]() { EXPECT_EQ(kPeerAddress, connection_.peer_address()); })) - .WillOnce(Invoke([=]() { - EXPECT_EQ((GetQuicReloadableFlag(quic_start_peer_migration_earlier) || - !GetParam().version.HasIetfQuicFrames() - ? kNewPeerAddress - : kPeerAddress), - connection_.peer_address()); - })); + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); })); QuicFrames frames; frames.push_back(QuicFrame(frame1_)); ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, @@ -1876,6 +1880,83 @@ TEST_P(QuicConnectionTest, PeerIpAddressChangeAtServer) { EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); } +TEST_P(QuicConnectionTest, PeerIpAddressChangeAtServerWithMissingConnectionId) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + + QuicConnectionId client_cid0 = TestConnectionId(1); + QuicConnectionId client_cid1 = TestConnectionId(3); + QuicConnectionId server_cid1; + SetClientConnectionId(client_cid0); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Prevent packets from being coalesced. + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + QuicConnectionPeer::SetAddressValidated(&connection_); + + // Sends new server CID to client. + EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) + .WillOnce( + Invoke([&](const QuicConnectionId& cid) { server_cid1 = cid; })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.OnHandshakeComplete(); + + // Clear direct_peer_address. + QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); + // Clear effective_peer_address, it is the same as direct_peer_address for + // this test. + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, + QuicSocketAddress()); + EXPECT_FALSE(connection_.effective_peer_address().IsInitialized()); + + const QuicSocketAddress kNewPeerAddress = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(2); + QuicFrames frames; + frames.push_back(QuicFrame(frame1_)); + ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Send some data to make connection has packets in flight. + connection_.SendStreamData3(); + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Process another packet with a different peer address on server side will + // start connection migration. + peer_creator_.SetServerConnectionId(server_cid1); + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + // Do not propagate OnCanWrite() to session notifier. + EXPECT_CALL(visitor_, OnCanWrite()).Times(AtLeast(1u)); + + QuicFrames frames2; + frames2.push_back(QuicFrame(frame2_)); + ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, + ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); + EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + + // Writing path response & reverse path challenge is blocked due to missing + // client connection ID, i.e., packets_write_attempts is unchanged. + EXPECT_EQ(1u, writer_->packets_write_attempts()); + + // Receives new client CID from client would unblock write. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + connection_.SendStreamData3(); + + EXPECT_EQ(2u, writer_->packets_write_attempts()); +} + TEST_P(QuicConnectionTest, EffectivePeerAddressChangeAtServer) { set_perspective(Perspective::IS_SERVER); QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); @@ -1884,6 +1965,9 @@ TEST_P(QuicConnectionTest, EffectivePeerAddressChangeAtServer) { QuicConnectionPeer::SetAddressValidated(&connection_); } connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); EXPECT_CALL(visitor_, GetHandshakeState()) .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); @@ -1991,19 +2075,156 @@ TEST_P(QuicConnectionTest, EffectivePeerAddressChangeAtServer) { } } +// Regression test for b/200020764. +TEST_P(QuicConnectionTest, ConnectionMigrationWithPendingPaddingBytes) { + // TODO(haoyuewang) Move these test setup code to a common member function. + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicConnectionPeer::SetPeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetAddressValidated(&connection_); + + // Sends new server CID to client. + QuicConnectionId new_cid; + EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { new_cid = cid; })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + packet_creator->FlushCurrentPacket(); + packet_creator->AddPendingPadding(50u); + const QuicSocketAddress kPeerAddress3 = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/56789); + auto ack_frame = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _)); + EXPECT_CALL(visitor_, OnConnectionMigration(PORT_CHANGE)).Times(1); + ProcessFramesPacketWithAddresses({QuicFrame(&ack_frame)}, kSelfAddress, + kPeerAddress3, ENCRYPTION_FORWARD_SECURE); + if (GetQuicReloadableFlag( + quic_flush_pending_frames_and_padding_bytes_on_migration)) { + // Any pending frames/padding should be flushed before default_path_ is + // temporarily reset. + ASSERT_EQ(connection_.self_address_on_default_path_while_sending_packet() + .host() + .address_family(), + IpAddressFamily::IP_V6); + } else { + ASSERT_EQ(connection_.self_address_on_default_path_while_sending_packet() + .host() + .address_family(), + IpAddressFamily::IP_UNSPEC); + } +} + +// Regression test for b/196208556. +TEST_P(QuicConnectionTest, + ReversePathValidationResponseReceivedFromUnexpectedPeerAddress) { + set_perspective(Perspective::IS_SERVER); + if (!connection_.connection_migration_use_new_cid()) { + return; + } + QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); + connection_.CreateConnectionIdManager(); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + QuicConnectionPeer::SetPeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetEffectivePeerAddress(&connection_, kPeerAddress); + QuicConnectionPeer::SetAddressValidated(&connection_); + EXPECT_EQ(kPeerAddress, connection_.peer_address()); + EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); + + // Sends new server CID to client. + QuicConnectionId new_cid; + EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) + .WillOnce(Invoke([&](const QuicConnectionId& cid) { new_cid = cid; })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); + + // Process a non-probing packet to migrate to path 2 and kick off reverse path + // validation. + EXPECT_CALL(visitor_, OnConnectionMigration(IPV6_TO_IPV4_CHANGE)).Times(1); + const QuicSocketAddress kPeerAddress2 = + QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); + peer_creator_.SetServerConnectionId(new_cid); + ProcessFramesPacketWithAddresses({QuicFrame(QuicPingFrame())}, kSelfAddress, + kPeerAddress2, ENCRYPTION_FORWARD_SECURE); + EXPECT_FALSE(writer_->path_challenge_frames().empty()); + QuicPathFrameBuffer reverse_path_challenge_payload = + writer_->path_challenge_frames().front().data_buffer; + + // Receiveds a packet from path 3 with PATH_RESPONSE frame intended to + // validate path 2 and a non-probing frame. + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + const QuicSocketAddress kPeerAddress3 = + QuicSocketAddress(QuicIpAddress::Loopback6(), /*port=*/56789); + auto ack_frame = InitAckFrame(1); + EXPECT_CALL(visitor_, OnConnectionMigration(IPV4_TO_IPV6_CHANGE)).Times(1); + EXPECT_CALL(visitor_, MaybeSendAddressToken()).WillOnce(Invoke([this]() { + connection_.SendControlFrame( + QuicFrame(new QuicNewTokenFrame(1, "new_token"))); + return true; + })); + ProcessFramesPacketWithAddresses({QuicFrame(new QuicPathResponseFrame( + 0, reverse_path_challenge_payload)), + QuicFrame(&ack_frame)}, + kSelfAddress, kPeerAddress3, + ENCRYPTION_FORWARD_SECURE); + } +} + TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { set_perspective(Perspective::IS_SERVER); - if (!connection_.validate_client_address()) { + if (!connection_.connection_migration_use_new_cid()) { return; } QuicPacketCreatorPeer::SetSendVersionInPacket(creator_, false); EXPECT_EQ(Perspective::IS_SERVER, connection_.perspective()); + SetClientConnectionId(TestConnectionId(1)); + connection_.CreateConnectionIdManager(); connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); // Prevent packets from being coalesced. EXPECT_CALL(visitor_, GetHandshakeState()) .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); QuicConnectionPeer::SetAddressValidated(&connection_); + + QuicConnectionId client_cid0 = connection_.client_connection_id(); + QuicConnectionId client_cid1 = TestConnectionId(2); + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId server_cid1; + // Sends new server CID to client. + EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) + .WillOnce( + Invoke([&](const QuicConnectionId& cid) { server_cid1 = cid; })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); connection_.OnHandshakeComplete(); + // Receives new client CID from client. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); // Clear direct_peer_address. QuicConnectionPeer::SetDirectPeerAddress(&connection_, QuicSocketAddress()); @@ -2018,13 +2239,8 @@ TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { EXPECT_CALL(visitor_, OnStreamFrame(_)) .WillOnce(Invoke( [=]() { EXPECT_EQ(kPeerAddress, connection_.peer_address()); })) - .WillOnce(Invoke([=]() { - EXPECT_EQ((GetQuicReloadableFlag(quic_start_peer_migration_earlier) || - !GetParam().version.HasIetfQuicFrames() - ? kNewPeerAddress - : kPeerAddress), - connection_.peer_address()); - })); + .WillOnce(Invoke( + [=]() { EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); })); QuicFrames frames; frames.push_back(QuicFrame(frame1_)); ProcessFramesPacketWithAddresses(frames, kSelfAddress, kPeerAddress, @@ -2043,6 +2259,7 @@ TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { frames2.push_back(QuicFrame(frame2_)); QuicPaddingFrame padding; frames2.push_back(QuicFrame(padding)); + peer_creator_.SetServerConnectionId(server_cid1); ProcessFramesPacketWithAddresses(frames2, kSelfAddress, kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); @@ -2056,6 +2273,15 @@ TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->client_connection_id, client_cid1); + EXPECT_EQ(default_path->server_connection_id, server_cid1); + EXPECT_EQ(alternative_path->client_connection_id, client_cid0); + EXPECT_EQ(alternative_path->server_connection_id, server_cid0); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid1); for (size_t i = 0; i < QuicPathValidator::kMaxRetryTimes; ++i) { clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(3 * kInitialRttMs)); @@ -2067,6 +2293,10 @@ TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { EXPECT_EQ(IPV6_TO_IPV4_CHANGE, connection_.active_effective_peer_migration_type()); + // Make sure anti-amplification limit is not reached. + ProcessFramesPacketWithAddresses( + {QuicFrame(QuicPingFrame()), QuicFrame(QuicPaddingFrame())}, kSelfAddress, + kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); SendStreamDataToPeer(1, "foo", 0, NO_FIN, nullptr); EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); @@ -2082,6 +2312,19 @@ TEST_P(QuicConnectionTest, ReversePathValidationFailureAtServer) { EXPECT_EQ(connection_.sent_packet_manager().GetSendAlgorithm(), send_algorithm_); EXPECT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + + // Verify that default_path_ is reverted and alternative_path_ is cleared. + EXPECT_EQ(default_path->client_connection_id, client_cid0); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/1u)); + retire_peer_issued_cid_alarm->Fire(); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); } TEST_P(QuicConnectionTest, ReceivePathProbeWithNoAddressChangeAtServer) { @@ -2614,7 +2857,7 @@ TEST_P(QuicConnectionTest, PeerAddressChangeAtClient) { TEST_P(QuicConnectionTest, MaxPacketSize) { EXPECT_EQ(Perspective::IS_CLIENT, connection_.perspective()); - EXPECT_EQ(1350u, connection_.max_packet_length()); + EXPECT_EQ(1250u, connection_.max_packet_length()); } TEST_P(QuicConnectionTest, PeerLowersMaxPacketSize) { @@ -2651,6 +2894,18 @@ TEST_P(QuicConnectionTest, SmallerServerMaxPacketSize) { EXPECT_EQ(1000u, connection.max_packet_length()); } +TEST_P(QuicConnectionTest, LowerServerResponseMtuTest) { + set_perspective(Perspective::IS_SERVER); + connection_.SetMaxPacketLength(1000); + EXPECT_EQ(1000u, connection_.max_packet_length()); + + SetQuicFlag(FLAGS_quic_use_lower_server_response_mtu_for_test, true); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(::testing::AtMost(1)); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(::testing::AtMost(1)); + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_EQ(1250u, connection_.max_packet_length()); +} + TEST_P(QuicConnectionTest, IncreaseServerMaxPacketSize) { set_perspective(Perspective::IS_SERVER); connection_.SetMaxPacketLength(1000); @@ -2853,8 +3108,7 @@ TEST_P(QuicConnectionTest, PacketsOutOfOrderWithAdditionsAndLeastAwaiting) { TEST_P(QuicConnectionTest, RejectUnencryptedStreamData) { // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. if (!IsDefaultTestConfiguration() || - (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) && - VersionHasIetfQuicFrames(version().transport_version))) { + VersionHasIetfQuicFrames(version().transport_version)) { return; } @@ -3186,20 +3440,6 @@ TEST_P(QuicConnectionTest, TooManySentPackets) { ProcessFramePacket(QuicFrame(QuicPingFrame())); - if (!GetQuicReloadableFlag( - quic_close_connection_with_too_many_outstanding_packets)) { - // When the flag is false, the ping packet processed above shouldn't cause - // the connection to close. But the ack packet below will. - EXPECT_TRUE(connection_.connected()); - - // Ack packet 1, which leaves more than the limit outstanding. - EXPECT_CALL(*send_algorithm_, OnCongestionEvent(true, _, _, _, _)); - - // Nack the first packet and ack the rest, leaving a huge gap. - QuicAckFrame frame1 = ConstructAckFrame(num_packets, 1); - ProcessAckPacket(&frame1); - } - TestConnectionCloseQuicErrorCode(QUIC_TOO_MANY_OUTSTANDING_SENT_PACKETS); } @@ -3407,15 +3647,15 @@ TEST_P(QuicConnectionTest, FramePackingNonCryptoThenCrypto) { EXPECT_EQ(0u, connection_.NumQueuedPackets()); EXPECT_FALSE(connection_.HasQueuedData()); - // Parse the last packet and ensure it's the crypto stream frame. - EXPECT_EQ(2u, writer_->frame_count()); - ASSERT_EQ(1u, writer_->padding_frames().size()); + // Parse the last packet and ensure it contains a crypto stream frame. + EXPECT_LE(2u, writer_->frame_count()); + ASSERT_LE(1u, writer_->padding_frames().size()); if (!QuicVersionUsesCryptoFrames(connection_.transport_version())) { ASSERT_EQ(1u, writer_->stream_frames().size()); EXPECT_EQ(QuicUtils::GetCryptoStreamId(connection_.transport_version()), writer_->stream_frames()[0]->stream_id); } else { - EXPECT_EQ(1u, writer_->crypto_frames().size()); + EXPECT_LE(1u, writer_->crypto_frames().size()); } } @@ -3596,7 +3836,7 @@ TEST_P(QuicConnectionTest, LargeSendWithPendingAck) { EXPECT_TRUE(connection_.HasPendingAcks()); // Send data and ensure the ack is bundled. - EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(8); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(9); size_t len = 10000; std::unique_ptr data_array(new char[len]); memset(data_array.get(), '?', len); @@ -3652,11 +3892,9 @@ TEST_P(QuicConnectionTest, OnCanWrite) { TEST_P(QuicConnectionTest, RetransmitOnNack) { QuicPacketNumber last_packet; - QuicByteCount second_packet_size; - SendStreamDataToPeer(3, "foo", 0, NO_FIN, &last_packet); // Packet 1 - second_packet_size = - SendStreamDataToPeer(3, "foos", 3, NO_FIN, &last_packet); // Packet 2 - SendStreamDataToPeer(3, "fooos", 7, NO_FIN, &last_packet); // Packet 3 + SendStreamDataToPeer(3, "foo", 0, NO_FIN, &last_packet); + SendStreamDataToPeer(3, "foos", 3, NO_FIN, &last_packet); + SendStreamDataToPeer(3, "fooos", 7, NO_FIN, &last_packet); EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); @@ -4306,7 +4544,7 @@ TEST_P(QuicConnectionTest, TLP) { } TEST_P(QuicConnectionTest, TailLossProbeDelayForStreamDataInTLPR) { - if (connection_.PtoEnabled()) { + if (connection_.PtoEnabled() || GetQuicReloadableFlag(quic_deprecate_tlpr)) { return; } @@ -4341,7 +4579,7 @@ TEST_P(QuicConnectionTest, TailLossProbeDelayForStreamDataInTLPR) { } TEST_P(QuicConnectionTest, TailLossProbeDelayForNonStreamDataInTLPR) { - if (connection_.PtoEnabled()) { + if (connection_.PtoEnabled() || GetQuicReloadableFlag(quic_deprecate_tlpr)) { return; } @@ -6377,7 +6615,6 @@ TEST_P(QuicConnectionTest, SendDelayedAckDecimationUnlimitedAggregation) { EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); QuicConfig config; QuicTagVector connection_options; - connection_options.push_back(kACKD); // No limit on the number of packets received before sending an ack. connection_options.push_back(kAKDU); config.SetConnectionOptionsToSend(connection_options); @@ -7098,7 +7335,7 @@ TEST_P(QuicConnectionTest, CheckSendStats) { stats.bytes_retransmitted); EXPECT_EQ(3u, stats.packets_retransmitted); EXPECT_EQ(1u, stats.rto_count); - EXPECT_EQ(kDefaultMaxPacketSize, stats.max_packet_size); + EXPECT_EQ(kDefaultMaxPacketSize, stats.egress_mtu); } TEST_P(QuicConnectionTest, ProcessFramesIfPacketClosedConnection) { @@ -7804,8 +8041,7 @@ TEST_P(QuicConnectionTest, ServerReceivesChloOnNonCryptoStream) { EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); ForceProcessFramePacket(QuicFrame(frame1_)); - if (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) && - VersionHasIetfQuicFrames(version().transport_version)) { + if (VersionHasIetfQuicFrames(version().transport_version)) { // INITIAL packet should not contain STREAM frame. TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); } else { @@ -7827,8 +8063,7 @@ TEST_P(QuicConnectionTest, ClientReceivesRejOnNonCryptoStream) { EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)); ForceProcessFramePacket(QuicFrame(frame1_)); - if (GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) && - VersionHasIetfQuicFrames(version().transport_version)) { + if (VersionHasIetfQuicFrames(version().transport_version)) { // INITIAL packet should not contain STREAM frame. TestConnectionCloseQuicErrorCode(IETF_QUIC_PROTOCOL_VIOLATION); } else { @@ -8675,8 +8910,7 @@ TEST_P(QuicConnectionTest, SendMessage) { connection_.SetFromConfig(config); } std::string message(connection_.GetCurrentLargestMessagePayload() * 2, 'a'); - absl::string_view message_data(message); - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); + QuicMemSlice slice; { QuicConnection::ScopedPacketFlusher flusher(&connection_); connection_.SendStreamData3(); @@ -8684,36 +8918,23 @@ TEST_P(QuicConnectionTest, SendMessage) { // get sent, one contains stream frame, and the other only contains the // message frame. EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + slice = MemSliceFromString(absl::string_view( + message.data(), connection_.GetCurrentLargestMessagePayload())); EXPECT_EQ(MESSAGE_STATUS_SUCCESS, - connection_.SendMessage( - 1, - MakeSpan(connection_.helper()->GetStreamSendBufferAllocator(), - absl::string_view( - message_data.data(), - connection_.GetCurrentLargestMessagePayload()), - &storage), - false)); + connection_.SendMessage(1, absl::MakeSpan(&slice, 1), false)); } // Fail to send a message if connection is congestion control blocked. EXPECT_CALL(*send_algorithm_, CanSend(_)).WillOnce(Return(false)); + slice = MemSliceFromString("message"); EXPECT_EQ(MESSAGE_STATUS_BLOCKED, - connection_.SendMessage( - 2, - MakeSpan(connection_.helper()->GetStreamSendBufferAllocator(), - "message", &storage), - false)); + connection_.SendMessage(2, absl::MakeSpan(&slice, 1), false)); // Always fail to send a message which cannot fit into one packet. EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); + slice = MemSliceFromString(absl::string_view( + message.data(), connection_.GetCurrentLargestMessagePayload() + 1)); EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, - connection_.SendMessage( - 3, - MakeSpan(connection_.helper()->GetStreamSendBufferAllocator(), - absl::string_view( - message_data.data(), - connection_.GetCurrentLargestMessagePayload() + 1), - &storage), - false)); + connection_.SendMessage(3, absl::MakeSpan(&slice, 1), false)); } TEST_P(QuicConnectionTest, GetCurrentLargestMessagePayload) { @@ -8724,7 +8945,7 @@ TEST_P(QuicConnectionTest, GetCurrentLargestMessagePayload) { // that the encryption overhead is constant across versions. connection_.SetEncrypter(ENCRYPTION_INITIAL, std::make_unique(0x00)); - QuicPacketLength expected_largest_payload = 1319; + QuicPacketLength expected_largest_payload = 1219; if (connection_.version().SendsVariableLengthPacketNumberInLongHeader()) { expected_largest_payload += 3; } @@ -8759,7 +8980,7 @@ TEST_P(QuicConnectionTest, GetGuaranteedLargestMessagePayload) { // that the encryption overhead is constant across versions. connection_.SetEncrypter(ENCRYPTION_INITIAL, std::make_unique(0x00)); - QuicPacketLength expected_largest_payload = 1319; + QuicPacketLength expected_largest_payload = 1219; if (connection_.version().HasLongHeaderLengths()) { expected_largest_payload -= 2; } @@ -9934,12 +10155,13 @@ TEST_P(QuicConnectionTest, AntiAmplificationLimit) { EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); // Receives packet 1. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); const size_t anti_amplification_factor = GetQuicFlag(FLAGS_quic_anti_amplification_factor); // Verify now packets can be sent. - for (size_t i = 0; i < anti_amplification_factor; ++i) { + for (size_t i = 1; i < anti_amplification_factor; ++i) { EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); connection_.SendCryptoDataWithString("foo", i * 3); // Verify retransmission alarm is not set if throttled by anti-amplification @@ -9952,10 +10174,11 @@ TEST_P(QuicConnectionTest, AntiAmplificationLimit) { connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3); // Receives packet 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); // Verify more packets can be sent. - for (size_t i = anti_amplification_factor; i < anti_amplification_factor * 2; - ++i) { + for (size_t i = anti_amplification_factor + 1; + i < anti_amplification_factor * 2; ++i) { EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); connection_.SendCryptoDataWithString("foo", i * 3); } @@ -9964,6 +10187,7 @@ TEST_P(QuicConnectionTest, AntiAmplificationLimit) { connection_.SendCryptoDataWithString("foo", 2 * anti_amplification_factor * 3); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessPacket(3); // Verify anti-amplification limit is gone after address validation. for (size_t i = 0; i < 100; ++i) { @@ -10000,11 +10224,12 @@ TEST_P(QuicConnectionTest, 3AntiAmplificationLimit) { EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); // Receives packet 1. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); const size_t anti_amplification_factor = 3; // Verify now packets can be sent. - for (size_t i = 0; i < anti_amplification_factor; ++i) { + for (size_t i = 1; i < anti_amplification_factor; ++i) { EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); connection_.SendCryptoDataWithString("foo", i * 3); // Verify retransmission alarm is not set if throttled by anti-amplification @@ -10017,10 +10242,11 @@ TEST_P(QuicConnectionTest, 3AntiAmplificationLimit) { connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3); // Receives packet 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); // Verify more packets can be sent. - for (size_t i = anti_amplification_factor; i < anti_amplification_factor * 2; - ++i) { + for (size_t i = anti_amplification_factor + 1; + i < anti_amplification_factor * 2; ++i) { EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); connection_.SendCryptoDataWithString("foo", i * 3); } @@ -10029,6 +10255,7 @@ TEST_P(QuicConnectionTest, 3AntiAmplificationLimit) { connection_.SendCryptoDataWithString("foo", 2 * anti_amplification_factor * 3); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessPacket(3); // Verify anti-amplification limit is gone after address validation. for (size_t i = 0; i < 100; ++i) { @@ -10065,11 +10292,12 @@ TEST_P(QuicConnectionTest, 10AntiAmplificationLimit) { EXPECT_FALSE(connection_.GetRetransmissionAlarm()->IsSet()); // Receives packet 1. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); const size_t anti_amplification_factor = 10; // Verify now packets can be sent. - for (size_t i = 0; i < anti_amplification_factor; ++i) { + for (size_t i = 1; i < anti_amplification_factor; ++i) { EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); connection_.SendCryptoDataWithString("foo", i * 3); // Verify retransmission alarm is not set if throttled by anti-amplification @@ -10082,10 +10310,11 @@ TEST_P(QuicConnectionTest, 10AntiAmplificationLimit) { connection_.SendCryptoDataWithString("foo", anti_amplification_factor * 3); // Receives packet 2. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); // Verify more packets can be sent. - for (size_t i = anti_amplification_factor; i < anti_amplification_factor * 2; - ++i) { + for (size_t i = anti_amplification_factor + 1; + i < anti_amplification_factor * 2; ++i) { EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); connection_.SendCryptoDataWithString("foo", i * 3); } @@ -10094,6 +10323,7 @@ TEST_P(QuicConnectionTest, 10AntiAmplificationLimit) { connection_.SendCryptoDataWithString("foo", 2 * anti_amplification_factor * 3); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessPacket(3); // Verify anti-amplification limit is gone after address validation. for (size_t i = 0; i < 100; ++i) { @@ -10426,10 +10656,8 @@ TEST_P(QuicConnectionTest, MultiplePacketNumberSpacePto) { } void QuicConnectionTest::TestClientRetryHandling( - bool invalid_retry_tag, - bool missing_original_id_in_config, - bool wrong_original_id_in_config, - bool missing_retry_id_in_config, + bool invalid_retry_tag, bool missing_original_id_in_config, + bool wrong_original_id_in_config, bool missing_retry_id_in_config, bool wrong_retry_id_in_config) { if (invalid_retry_tag) { ASSERT_FALSE(missing_original_id_in_config); @@ -10445,16 +10673,16 @@ void QuicConnectionTest::TestClientRetryHandling( } // These values come from draft-ietf-quic-tls Appendix A.4. - char retry_packet_rfcv1[] = { + uint8_t retry_packet_rfcv1[] = { 0xff, 0x00, 0x00, 0x00, 0x01, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x04, 0xa2, 0x65, 0xba, 0x2e, 0xff, 0x4d, 0x82, 0x90, 0x58, 0xfb, 0x3f, 0x0f, 0x24, 0x96, 0xba}; - char retry_packet29[] = { + uint8_t retry_packet29[] = { 0xff, 0xff, 0x00, 0x00, 0x1d, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, 0xb5, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0xd1, 0x69, 0x26, 0xd8, 0x1f, 0x6f, 0x9c, 0xa2, 0x95, 0x3a, 0x8a, 0xa4, 0x57, 0x5e, 0x1e, 0x49}; - char* retry_packet; + uint8_t* retry_packet; size_t retry_packet_length; if (version() == ParsedQuicVersion::RFCv1()) { retry_packet = retry_packet_rfcv1; @@ -10468,19 +10696,21 @@ void QuicConnectionTest::TestClientRetryHandling( return; } - char original_connection_id_bytes[] = {0x83, 0x94, 0xc8, 0xf0, - 0x3e, 0x51, 0x57, 0x08}; - char new_connection_id_bytes[] = {0xf0, 0x67, 0xa5, 0x50, - 0x2a, 0x42, 0x62, 0xb5}; - char retry_token_bytes[] = {0x74, 0x6f, 0x6b, 0x65, 0x6e}; + uint8_t original_connection_id_bytes[] = {0x83, 0x94, 0xc8, 0xf0, + 0x3e, 0x51, 0x57, 0x08}; + uint8_t new_connection_id_bytes[] = {0xf0, 0x67, 0xa5, 0x50, + 0x2a, 0x42, 0x62, 0xb5}; + uint8_t retry_token_bytes[] = {0x74, 0x6f, 0x6b, 0x65, 0x6e}; QuicConnectionId original_connection_id( - original_connection_id_bytes, + reinterpret_cast(original_connection_id_bytes), ABSL_ARRAYSIZE(original_connection_id_bytes)); - QuicConnectionId new_connection_id(new_connection_id_bytes, - ABSL_ARRAYSIZE(new_connection_id_bytes)); + QuicConnectionId new_connection_id( + reinterpret_cast(new_connection_id_bytes), + ABSL_ARRAYSIZE(new_connection_id_bytes)); - std::string retry_token(retry_token_bytes, ABSL_ARRAYSIZE(retry_token_bytes)); + std::string retry_token(reinterpret_cast(retry_token_bytes), + ABSL_ARRAYSIZE(retry_token_bytes)); if (invalid_retry_tag) { // Flip the last bit of the retry packet to prevent the integrity tag @@ -10511,7 +10741,8 @@ void QuicConnectionTest::TestClientRetryHandling( // Process the RETRY packet. connection_.ProcessUdpPacket( kSelfAddress, kPeerAddress, - QuicReceivedPacket(retry_packet, retry_packet_length, clock_.Now())); + QuicReceivedPacket(reinterpret_cast(retry_packet), + retry_packet_length, clock_.Now())); if (invalid_retry_tag) { // Make sure we refuse to process a RETRY with invalid tag. @@ -10846,6 +11077,11 @@ TEST_P(QuicConnectionTest, DonotChangeQueuedAcks) { EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); ProcessPacket(2); ProcessPacket(3); @@ -10901,7 +11137,7 @@ TEST_P(QuicConnectionTest, BundleAckWithImmediateResponse) { connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([this]() { - connection_.SendControlFrame(QuicFrame(new QuicWindowUpdateFrame(1, 0, 0))); + notifier_.WriteOrBufferWindowUpate(0, 0); })); EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); ProcessDataPacket(1); @@ -11179,8 +11415,7 @@ TEST_P(QuicConnectionTest, ProcessUndecryptablePacketsBasedOnEncryptionLevel) { std::make_unique(0x01)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); // Verify all ENCRYPTION_HANDSHAKE packets get processed. - if (!GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) || - !VersionHasIetfQuicFrames(version().transport_version)) { + if (!VersionHasIetfQuicFrames(version().transport_version)) { EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(6); } connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); @@ -11473,21 +11708,82 @@ TEST_P(QuicConnectionTest, CoalscingPacketCausesInfiniteLoop) { connection_.GetRetransmissionAlarm()->Fire(); } -TEST_P(QuicConnectionTest, TestingLiveness) { - const size_t kMinRttMs = 40; - RttStats* rtt_stats = const_cast(manager_->GetRttStats()); - rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), - QuicTime::Delta::Zero(), QuicTime::Zero()); +TEST_P(QuicConnectionTest, ClientAckDelayForAsyncPacketProcessing) { + if (!version().HasIetfQuicFrames()) { + return; + } + // SetFromConfig is always called after construction from InitializeSession. + EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(AnyNumber()); QuicConfig config; + connection_.SetFromConfig(config); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + use_tagging_decrypter(); + connection_.SetEncrypter(ENCRYPTION_INITIAL, + std::make_unique(0x01)); + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x01)); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); - CryptoHandshakeMessage msg; - std::string error_details; - QuicConfig client_config; - client_config.SetInitialStreamFlowControlWindowToSend( - kInitialStreamFlowControlWindowForTest); - client_config.SetInitialSessionFlowControlWindowToSend( - kInitialSessionFlowControlWindowForTest); + // Received undecryptable HANDSHAKE 2. + ProcessDataPacketAtLevel(2, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + ASSERT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + // Received INITIAL 4 (which is retransmission of INITIAL 1) after 100ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(100)); + ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_INITIAL); + // Generate HANDSHAKE key. + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x01)); + EXPECT_TRUE(connection_.GetProcessUndecryptablePacketsAlarm()->IsSet()); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x01)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + // Verify HANDSHAKE packet gets processed. + connection_.GetProcessUndecryptablePacketsAlarm()->Fire(); + ASSERT_TRUE(connection_.HasPendingAcks()); + // Send ACKs. + clock_.AdvanceTime(connection_.GetAckAlarm()->deadline() - clock_.Now()); + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + connection_.GetAckAlarm()->Fire(); + ASSERT_FALSE(writer_->ack_frames().empty()); + // Verify the ack_delay_time in the INITIAL ACK frame is 1ms. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(1), + writer_->ack_frames()[0].ack_delay_time); + // Process the coalesced HANDSHAKE packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + ASSERT_FALSE(writer_->ack_frames().empty()); + if (GetQuicReloadableFlag( + quic_reset_per_packet_state_for_undecryptable_packets)) { + // Verify the ack_delay_time in the HANDSHAKE ACK frame includes the + // buffering time. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(101), + writer_->ack_frames()[0].ack_delay_time); + } else { + // This ack_delay_time is wrong. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(1), + writer_->ack_frames()[0].ack_delay_time); + } + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); +} + +TEST_P(QuicConnectionTest, TestingLiveness) { + const size_t kMinRttMs = 40; + RttStats* rtt_stats = const_cast(manager_->GetRttStats()); + rtt_stats->UpdateRtt(QuicTime::Delta::FromMilliseconds(kMinRttMs), + QuicTime::Delta::Zero(), QuicTime::Zero()); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + + CryptoHandshakeMessage msg; + std::string error_details; + QuicConfig client_config; + client_config.SetInitialStreamFlowControlWindowToSend( + kInitialStreamFlowControlWindowForTest); + client_config.SetInitialSessionFlowControlWindowToSend( + kInitialSessionFlowControlWindowForTest); client_config.SetIdleNetworkTimeout(QuicTime::Delta::FromSeconds(30)); client_config.ToHandshakeMessage(&msg, connection_.transport_version()); const QuicErrorCode error = @@ -11808,15 +12104,17 @@ TEST_P(QuicConnectionTest, NewPathValidationCancelsPreviousOne) { const QuicSocketAddress kNewSelfAddress2(QuicIpAddress::Any4(), 12346); EXPECT_NE(kNewSelfAddress2, connection_.self_address()); TestPacketWriter new_writer2(version(), &clock_, Perspective::IS_CLIENT); - EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) - .Times(AtLeast(1u)) - .WillOnce(Invoke([&]() { - EXPECT_EQ(1u, new_writer2.packets_write_attempts()); - EXPECT_EQ(1u, new_writer2.path_challenge_frames().size()); - EXPECT_EQ(1u, new_writer2.padding_frames().size()); - EXPECT_EQ(kNewSelfAddress2.host(), - new_writer2.last_write_source_address()); - })); + if (!connection_.connection_migration_use_new_cid()) { + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1u)) + .WillOnce(Invoke([&]() { + EXPECT_EQ(1u, new_writer2.packets_write_attempts()); + EXPECT_EQ(1u, new_writer2.path_challenge_frames().size()); + EXPECT_EQ(1u, new_writer2.padding_frames().size()); + EXPECT_EQ(kNewSelfAddress2.host(), + new_writer2.last_write_source_address()); + })); + } bool success2 = false; connection_.ValidatePath( std::make_unique( @@ -11825,7 +12123,13 @@ TEST_P(QuicConnectionTest, NewPathValidationCancelsPreviousOne) { &connection_, kNewSelfAddress2, connection_.peer_address(), &success2)); EXPECT_FALSE(success); - EXPECT_TRUE(connection_.HasPendingPathValidation()); + if (connection_.connection_migration_use_new_cid()) { + // There is no pening path validation as there is no available connection + // ID. + EXPECT_FALSE(connection_.HasPendingPathValidation()); + } else { + EXPECT_TRUE(connection_.HasPendingPathValidation()); + } } // Regression test for b/182571515. @@ -11910,7 +12214,7 @@ TEST_P(QuicConnectionTest, PathValidationReceivesStatelessReset) { // writer. TEST_P(QuicConnectionTest, SendPathChallengeUsingBlockedNewSocket) { if (!VersionHasIetfQuicFrames(connection_.version().transport_version) || - !connection_.use_path_validator()) { + !connection_.connection_migration_use_new_cid()) { return; } PathProbeTestInit(Perspective::IS_CLIENT); @@ -11920,6 +12224,7 @@ TEST_P(QuicConnectionTest, SendPathChallengeUsingBlockedNewSocket) { new_writer.BlockOnNextWrite(); EXPECT_CALL(visitor_, OnWriteBlocked()).Times(0); EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)) + .Times(AtLeast(1)) .WillOnce(Invoke([&]() { // Even though the socket is blocked, the PATH_CHALLENGE should still be // treated as sent. @@ -11940,7 +12245,9 @@ TEST_P(QuicConnectionTest, SendPathChallengeUsingBlockedNewSocket) { new_writer.SetWritable(); // Write event on the default socket shouldn't make any difference. connection_.OnCanWrite(); - EXPECT_EQ(0u, writer_->packets_write_attempts()); + // A NEW_CONNECTION_ID frame is received in PathProbeTestInit and OnCanWrite + // will write a acking packet. + EXPECT_EQ(1u, writer_->packets_write_attempts()); EXPECT_EQ(1u, new_writer.packets_write_attempts()); } @@ -12125,9 +12432,12 @@ TEST_P(QuicConnectionTest, return; } PathProbeTestInit(Perspective::IS_CLIENT); + // Make sure there is no outstanding ACK_FRAME to write. + connection_.OnCanWrite(); + uint32_t num_packets_write_attempts = writer_->packets_write_attempts(); writer_->SetShouldWriteFail(); - writer_->SetWriteError(EMSGSIZE); + writer_->SetWriteError(QUIC_EMSGSIZE); const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Any4(), 12345); EXPECT_CALL(visitor_, OnConnectionClosed(_, ConnectionCloseSource::FROM_SELF)) .Times(0u); @@ -12143,7 +12453,7 @@ TEST_P(QuicConnectionTest, EXPECT_TRUE(connection_.HasPendingPathValidation()); // Connection shouldn't be closed. EXPECT_TRUE(connection_.connected()); - EXPECT_EQ(1u, writer_->packets_write_attempts()); + EXPECT_EQ(++num_packets_write_attempts, writer_->packets_write_attempts()); EXPECT_EQ(1u, writer_->path_challenge_frames().size()); EXPECT_EQ(1u, writer_->padding_frames().size()); EXPECT_EQ(kNewPeerAddress, writer_->last_write_peer_address()); @@ -12518,7 +12828,7 @@ TEST_P(QuicConnectionTest, CoalescerHandlesInitialKeyDiscard) { connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, std::make_unique(0x02)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); - connection_.SendCryptoDataWithString(std::string(1300, 'a'), 0); + connection_.SendCryptoDataWithString(std::string(1200, 'a'), 0); // Verify this packet is on hold. EXPECT_EQ(0u, writer_->packets_write_attempts()); } @@ -12545,14 +12855,10 @@ TEST_P(QuicConnectionTest, ZeroRttRejectionAndMissingInitialKeys) { connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, std::make_unique(0x03)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); - connection_.SendCryptoStreamData(); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, std::make_unique(0x04)); connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); - if (!GetQuicReloadableFlag(quic_donot_write_mid_packet_processing)) { - // Retransmit rejected 0-RTT packets. - connection_.OnCanWrite(); - } // Advance INITIAL ack delay to trigger initial ACK to be sent AFTER // the retransmission of rejected 0-RTT packets while the HANDSHAKE // packet is still in the coalescer, such that the INITIAL key gets @@ -12564,7 +12870,7 @@ TEST_P(QuicConnectionTest, ZeroRttRejectionAndMissingInitialKeys) { use_tagging_decrypter(); connection_.SetEncrypter(ENCRYPTION_INITIAL, std::make_unique(0x01)); - connection_.SendCryptoStreamData(); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); // Send 0-RTT packet. connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, std::make_unique(0x02)); @@ -13468,17 +13774,34 @@ TEST_P(QuicConnectionTest, ServerHelloGetsReordered) { } TEST_P(QuicConnectionTest, MigratePath) { + EXPECT_CALL(visitor_, GetHandshakeState()) + .Times(testing::AtMost(2)) + .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); EXPECT_CALL(visitor_, OnPathDegrading()); connection_.OnPathDegradingDetected(); const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Any4(), 12345); EXPECT_NE(kNewSelfAddress, connection_.self_address()); + + // Buffer a packet. + EXPECT_CALL(visitor_, OnWriteBlocked()).Times(1); + writer_->SetWriteBlocked(); + connection_.SendMtuDiscoveryPacket(kMaxOutgoingPacketSize); + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + TestPacketWriter new_writer(version(), &clock_, Perspective::IS_CLIENT); EXPECT_CALL(visitor_, OnForwardProgressMadeAfterPathDegrading()); connection_.MigratePath(kNewSelfAddress, connection_.peer_address(), &new_writer, /*owns_writer=*/false); + EXPECT_EQ(kNewSelfAddress, connection_.self_address()); EXPECT_EQ(&new_writer, QuicConnectionPeer::GetWriter(&connection_)); EXPECT_FALSE(connection_.IsPathDegrading()); + // Buffered packet on the old path should be discarded. + if (connection_.connection_migration_use_new_cid()) { + EXPECT_EQ(0u, connection_.NumQueuedPackets()); + } else { + EXPECT_EQ(1u, connection_.NumQueuedPackets()); + } } TEST_P(QuicConnectionTest, MigrateToNewPathDuringProbing) { @@ -13516,6 +13839,8 @@ TEST_P(QuicConnectionTest, SingleAckInPacket) { connection_.RemoveEncrypter(ENCRYPTION_INITIAL); connection_.NeuterUnencryptedPackets(); connection_.OnHandshakeComplete(); + EXPECT_CALL(visitor_, GetHandshakeState()) + .WillRepeatedly(Return(HANDSHAKE_COMPLETE)); EXPECT_CALL(visitor_, OnStreamFrame(_)).WillOnce(Invoke([=]() { connection_.SendStreamData3(); @@ -13592,7 +13917,6 @@ TEST_P(QuicConnectionTest, NewTokenFrameInstigateAcks) { if (!version().HasIetfQuicFrames()) { return; } - SetQuicReloadableFlag(quic_enable_token_based_address_validation, true); EXPECT_CALL(visitor_, OnSuccessfulVersionNegotiation(_)); QuicNewTokenFrame* new_token = new QuicNewTokenFrame(); @@ -13607,7 +13931,6 @@ TEST_P(QuicConnectionTest, ServerClosesConnectionOnNewTokenFrame) { if (!version().HasIetfQuicFrames()) { return; } - SetQuicReloadableFlag(quic_enable_token_based_address_validation, true); set_perspective(Perspective::IS_SERVER); QuicNewTokenFrame* new_token = new QuicNewTokenFrame(); EXPECT_CALL(visitor_, OnNewTokenReceived(_)).Times(0); @@ -13710,8 +14033,7 @@ TEST_P(QuicConnectionTest, // Regression test for b/177312785 TEST_P(QuicConnectionTest, PeerMigrateBeforeHandshakeConfirm) { - if (!VersionHasIetfQuicFrames(version().transport_version) || - !GetQuicReloadableFlag(quic_start_peer_migration_earlier)) { + if (!VersionHasIetfQuicFrames(version().transport_version)) { return; } set_perspective(Perspective::IS_SERVER); @@ -13779,11 +14101,34 @@ TEST_P(QuicConnectionTest, TryToFlushAckWithAckQueued) { TEST_P(QuicConnectionTest, PathChallengeBeforePeerIpAddressChangeAtServer) { set_perspective(Perspective::IS_SERVER); - if (!connection_.validate_client_address()) { + if (!connection_.connection_migration_use_new_cid()) { return; } PathProbeTestInit(Perspective::IS_SERVER); + SetClientConnectionId(TestConnectionId(1)); + connection_.CreateConnectionIdManager(); + + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId client_cid0 = connection_.client_connection_id(); + QuicConnectionId client_cid1 = TestConnectionId(2); + QuicConnectionId server_cid1; + // Sends new server CID to client. + EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) + .WillOnce( + Invoke([&](const QuicConnectionId& cid) { server_cid1 = cid; })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + // Receives new client CID from client. + QuicNewConnectionIdFrame new_cid_frame; + new_cid_frame.connection_id = client_cid1; + new_cid_frame.sequence_number = 1u; + new_cid_frame.retire_prior_to = 0u; + connection_.OnNewConnectionIdFrame(new_cid_frame); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); + peer_creator_.SetServerConnectionId(server_cid1); const QuicSocketAddress kNewPeerAddress = QuicSocketAddress(QuicIpAddress::Loopback4(), /*port=*/23456); QuicPathFrameBuffer path_challenge_payload{0, 1, 2, 3, 4, 5, 6, 7}; @@ -13807,6 +14152,15 @@ TEST_P(QuicConnectionTest, PathChallengeBeforePeerIpAddressChangeAtServer) { EXPECT_EQ(kPeerAddress, connection_.peer_address()); EXPECT_EQ(kPeerAddress, connection_.effective_peer_address()); EXPECT_TRUE(connection_.HasPendingPathValidation()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->client_connection_id, client_cid0); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_EQ(alternative_path->client_connection_id, client_cid1); + EXPECT_EQ(alternative_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid0); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); // Process another packet with a different peer address on server side will // start connection migration. @@ -13843,6 +14197,14 @@ TEST_P(QuicConnectionTest, PathChallengeBeforePeerIpAddressChangeAtServer) { EXPECT_CALL(*send_algorithm_, InRecovery()).Times(AnyNumber()); EXPECT_CALL(*send_algorithm_, PopulateConnectionStats(_)).Times(AnyNumber()); connection_.SetSendAlgorithm(send_algorithm_); + EXPECT_EQ(default_path->client_connection_id, client_cid1); + EXPECT_EQ(default_path->server_connection_id, server_cid1); + // The previous default path is kept as alternative path before reverse path + // validation finishes. + EXPECT_EQ(alternative_path->client_connection_id, client_cid0); + EXPECT_EQ(alternative_path->server_connection_id, server_cid0); + EXPECT_EQ(packet_creator->GetDestinationConnectionId(), client_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid1); EXPECT_EQ(kNewPeerAddress, connection_.peer_address()); EXPECT_EQ(kNewPeerAddress, connection_.effective_peer_address()); @@ -13866,10 +14228,21 @@ TEST_P(QuicConnectionTest, PathChallengeBeforePeerIpAddressChangeAtServer) { ProcessFramesPacketWithAddresses(frames3, kSelfAddress, kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); EXPECT_EQ(NO_CHANGE, connection_.active_effective_peer_migration_type()); + // Verify that alternative_path_ is cleared and the peer CID is retired. + EXPECT_TRUE(alternative_path->client_connection_id.IsEmpty()); + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + auto* retire_peer_issued_cid_alarm = + connection_.GetRetirePeerIssuedConnectionIdAlarm(); + ASSERT_TRUE(retire_peer_issued_cid_alarm->IsSet()); + EXPECT_CALL(visitor_, SendRetireConnectionId(/*sequence_number=*/0u)); + retire_peer_issued_cid_alarm->Fire(); // Verify the anti-amplification limit is lifted by sending a packet larger // than the anti-amplification limit. EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + EXPECT_CALL(*send_algorithm_, PacingRate(_)) + .WillRepeatedly(Return(QuicBandwidth::Zero())); connection_.SendCryptoDataWithString(std::string(1200, 'a'), 0); EXPECT_EQ(1u, connection_.GetStats().num_validated_peer_migration); } @@ -13877,12 +14250,25 @@ TEST_P(QuicConnectionTest, PathChallengeBeforePeerIpAddressChangeAtServer) { TEST_P(QuicConnectionTest, PathValidationSucceedsBeforePeerIpAddressChangeAtServer) { set_perspective(Perspective::IS_SERVER); - if (!connection_.validate_client_address()) { + if (!connection_.connection_migration_use_new_cid()) { return; } PathProbeTestInit(Perspective::IS_SERVER); + connection_.CreateConnectionIdManager(); + + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId server_cid1; + // Sends new server CID to client. + EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) + .WillOnce( + Invoke([&](const QuicConnectionId& cid) { server_cid1 = cid; })); + EXPECT_CALL(visitor_, SendNewConnectionId(_)); + connection_.MaybeSendConnectionIdToClient(); + auto* packet_creator = QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); // Receive probing packet with new peer address. + peer_creator_.SetServerConnectionId(server_cid1); const QuicSocketAddress kNewPeerAddress(QuicIpAddress::Loopback4(), /*port=*/23456); QuicPathFrameBuffer payload; @@ -13907,6 +14293,12 @@ TEST_P(QuicConnectionTest, ProcessFramesPacketWithAddresses(frames1, kSelfAddress, kNewPeerAddress, ENCRYPTION_FORWARD_SECURE); EXPECT_TRUE(connection_.HasPendingPathValidation()); + const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); + const auto* alternative_path = + QuicConnectionPeer::GetAlternativePath(&connection_); + EXPECT_EQ(default_path->server_connection_id, server_cid0); + EXPECT_EQ(alternative_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid0); // Receive PATH_RESPONSE should mark the new peer address validated. QuicFrames frames3; @@ -13944,6 +14336,12 @@ TEST_P(QuicConnectionTest, EXPECT_NE(connection_.sent_packet_manager().GetSendAlgorithm(), send_algorithm_); + EXPECT_EQ(default_path->server_connection_id, server_cid1); + EXPECT_EQ(packet_creator->GetSourceConnectionId(), server_cid1); + // Verify that alternative_path_ is cleared. + EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + // Switch to use the mock send algorithm. send_algorithm_ = new StrictMock(); EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); @@ -14030,12 +14428,18 @@ TEST_P(QuicConnectionTest, TEST_P(QuicConnectionTest, PathValidationFailedOnClientDueToLackOfServerConnectionId) { - if (!connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + config.SetConnectionOptionsToSend({kRVCM}); + } + if (!connection_.connection_migration_use_new_cid()) { return; } - QuicConnectionPeer::EnableConnectionMigrationUseNewCID(&connection_); - PathProbeTestInit(Perspective::IS_CLIENT); + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); const QuicSocketAddress kNewSelfAddress(QuicIpAddress::Loopback4(), /*port=*/34567); @@ -14052,29 +14456,40 @@ TEST_P(QuicConnectionTest, TEST_P(QuicConnectionTest, PathValidationFailedOnClientDueToLackOfClientConnectionIdTheSecondTime) { - if (!connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { return; } - - QuicConnectionPeer::EnableConnectionMigrationUseNewCID(&connection_); - PathProbeTestInit(Perspective::IS_CLIENT); + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); SetClientConnectionId(TestConnectionId(1)); // Make sure server connection ID is available for the 1st validation. + QuicConnectionId server_cid0 = connection_.connection_id(); + QuicConnectionId server_cid1 = TestConnectionId(2); + QuicConnectionId server_cid2 = TestConnectionId(4); + QuicConnectionId client_cid1; QuicNewConnectionIdFrame frame1; - frame1.connection_id = TestConnectionId(2); + frame1.connection_id = server_cid1; frame1.sequence_number = 1u; frame1.retire_prior_to = 0u; frame1.stateless_reset_token = QuicUtils::GenerateStatelessResetToken(frame1.connection_id); connection_.OnNewConnectionIdFrame(frame1); + const auto* packet_creator = + QuicConnectionPeer::GetPacketCreator(&connection_); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), server_cid0); // Client will issue a new client connection ID to server. - QuicConnectionId new_client_connection_id; EXPECT_CALL(visitor_, SendNewConnectionId(_)) .WillOnce(Invoke([&](const QuicNewConnectionIdFrame& frame) { - new_client_connection_id = frame.connection_id; + client_cid1 = frame.connection_id; })); const QuicSocketAddress kSelfAddress1(QuicIpAddress::Any4(), 12345); @@ -14092,15 +14507,15 @@ TEST_P(QuicConnectionTest, &new_writer, /*owns_writer=*/false)); QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath(&connection_); const auto* default_path = QuicConnectionPeer::GetDefaultPath(&connection_); - EXPECT_EQ(default_path->client_connection_id, new_client_connection_id); - EXPECT_EQ(default_path->server_connection_id, frame1.connection_id); + EXPECT_EQ(default_path->client_connection_id, client_cid1); + EXPECT_EQ(default_path->server_connection_id, server_cid1); EXPECT_EQ(default_path->stateless_reset_token, frame1.stateless_reset_token); - EXPECT_TRUE(default_path->stateless_reset_token_received); const auto* alternative_path = QuicConnectionPeer::GetAlternativePath(&connection_); EXPECT_TRUE(alternative_path->client_connection_id.IsEmpty()); EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); - EXPECT_FALSE(alternative_path->stateless_reset_token_received); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); + ASSERT_EQ(packet_creator->GetDestinationConnectionId(), server_cid1); // Client will retire server connection ID on old default_path. auto* retire_peer_issued_cid_alarm = @@ -14111,7 +14526,7 @@ TEST_P(QuicConnectionTest, // Another server connection ID is available to client. QuicNewConnectionIdFrame frame2; - frame2.connection_id = TestConnectionId(4); + frame2.connection_id = server_cid2; frame2.sequence_number = 2u; frame2.retire_prior_to = 1u; frame2.stateless_reset_token = @@ -14132,12 +14547,16 @@ TEST_P(QuicConnectionTest, } TEST_P(QuicConnectionTest, ServerConnectionIdRetiredUponPathValidationFailure) { - if (!connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { return; } - - QuicConnectionPeer::EnableConnectionMigrationUseNewCID(&connection_); PathProbeTestInit(Perspective::IS_CLIENT); // Make sure server connection ID is available for validation. @@ -14166,7 +14585,7 @@ TEST_P(QuicConnectionTest, ServerConnectionIdRetiredUponPathValidationFailure) { QuicConnectionPeer::GetAlternativePath(&connection_); EXPECT_TRUE(alternative_path->client_connection_id.IsEmpty()); EXPECT_TRUE(alternative_path->server_connection_id.IsEmpty()); - EXPECT_FALSE(alternative_path->stateless_reset_token_received); + EXPECT_FALSE(alternative_path->stateless_reset_token.has_value()); // Client will retire server connection ID on alternative_path. auto* retire_peer_issued_cid_alarm = @@ -14178,12 +14597,18 @@ TEST_P(QuicConnectionTest, ServerConnectionIdRetiredUponPathValidationFailure) { TEST_P(QuicConnectionTest, MigratePathDirectlyFailedDueToLackOfServerConnectionId) { - if (!connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { return; } - QuicConnectionPeer::EnableConnectionMigrationUseNewCID(&connection_); - PathProbeTestInit(Perspective::IS_CLIENT); + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); const QuicSocketAddress kSelfAddress1(QuicIpAddress::Any4(), 12345); ASSERT_NE(kSelfAddress1, connection_.self_address()); @@ -14195,12 +14620,18 @@ TEST_P(QuicConnectionTest, TEST_P(QuicConnectionTest, MigratePathDirectlyFailedDueToLackOfClientConnectionIdTheSecondTime) { - if (!connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { return; } - QuicConnectionPeer::EnableConnectionMigrationUseNewCID(&connection_); - PathProbeTestInit(Perspective::IS_CLIENT); + PathProbeTestInit(Perspective::IS_CLIENT, + /*receive_new_server_connection_id=*/false); SetClientConnectionId(TestConnectionId(1)); // Make sure server connection ID is available for the 1st migration. @@ -14231,7 +14662,6 @@ TEST_P(QuicConnectionTest, EXPECT_EQ(default_path->client_connection_id, new_client_connection_id); EXPECT_EQ(default_path->server_connection_id, frame1.connection_id); EXPECT_EQ(default_path->stateless_reset_token, frame1.stateless_reset_token); - EXPECT_TRUE(default_path->stateless_reset_token_received); // Client will retire server connection ID on old default_path. auto* retire_peer_issued_cid_alarm = @@ -14253,9 +14683,11 @@ TEST_P(QuicConnectionTest, // would fail due to lack of client connection ID. const QuicSocketAddress kSelfAddress2(QuicIpAddress::Loopback4(), /*port=*/45678); - ASSERT_FALSE(connection_.MigratePath(kSelfAddress2, - connection_.peer_address(), &new_writer, - /*owns_writer=*/false)); + auto new_writer2 = std::make_unique(version(), &clock_, + Perspective::IS_CLIENT); + ASSERT_FALSE(connection_.MigratePath( + kSelfAddress2, connection_.peer_address(), new_writer2.release(), + /*owns_writer=*/true)); } TEST_P(QuicConnectionTest, @@ -14263,7 +14695,6 @@ TEST_P(QuicConnectionTest, if (!version().HasIetfQuicFrames()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); set_perspective(Perspective::IS_SERVER); ASSERT_TRUE(connection_.client_connection_id().IsEmpty()); @@ -14288,7 +14719,6 @@ TEST_P(QuicConnectionTest, NewConnectionIdFrameResultsInError) { if (!version().HasIetfQuicFrames()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); connection_.CreateConnectionIdManager(); ASSERT_FALSE(connection_.connection_id().IsEmpty()); @@ -14313,7 +14743,6 @@ TEST_P(QuicConnectionTest, if (!version().HasIetfQuicFrames()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); connection_.CreateConnectionIdManager(); QuicNewConnectionIdFrame frame; @@ -14350,7 +14779,6 @@ TEST_P(QuicConnectionTest, if (!version().HasIetfQuicFrames()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); set_perspective(Perspective::IS_SERVER); SetClientConnectionId(TestConnectionId(0)); @@ -14387,10 +14815,9 @@ TEST_P( QuicConnectionTest, ReplacePeerIssuedConnectionIdOnBothPathsTriggeredByNewConnectionIdFrame) { if (!version().HasIetfQuicFrames() || !connection_.use_path_validator() || - !connection_.use_connection_id_on_default_path()) { + !connection_.count_bytes_on_alternative_path_separately()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); PathProbeTestInit(Perspective::IS_SERVER); SetClientConnectionId(TestConnectionId(0)); @@ -14446,10 +14873,9 @@ TEST_P( TEST_P(QuicConnectionTest, CloseConnectionAfterReceiveRetireConnectionIdWhenNoCIDIssued) { if (!version().HasIetfQuicFrames() || - GetQuicReloadableFlag(quic_use_connection_id_on_default_path)) { + !connection_.connection_migration_use_new_cid()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); set_perspective(Perspective::IS_SERVER); EXPECT_CALL(visitor_, BeforeConnectionCloseSent()); @@ -14467,10 +14893,9 @@ TEST_P(QuicConnectionTest, TEST_P(QuicConnectionTest, RetireConnectionIdFrameResultsInError) { if (!version().HasIetfQuicFrames() || - GetQuicReloadableFlag(quic_use_connection_id_on_default_path)) { + !connection_.connection_migration_use_new_cid()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); set_perspective(Perspective::IS_SERVER); connection_.CreateConnectionIdManager(); @@ -14493,7 +14918,9 @@ TEST_P(QuicConnectionTest, RetireConnectionIdFrameResultsInError) { TEST_P(QuicConnectionTest, ServerRetireSelfIssuedConnectionIdWithoutSendingNewConnectionIdBefore) { - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); + if (!version().HasIetfQuicFrames()) { + return; + } set_perspective(Perspective::IS_SERVER); connection_.CreateConnectionIdManager(); @@ -14504,26 +14931,24 @@ TEST_P(QuicConnectionTest, QuicConnectionId cid0 = connection_id_; QuicRetireConnectionIdFrame frame; frame.sequence_number = 0u; - if (!GetQuicReloadableFlag(quic_use_connection_id_on_default_path)) { + if (connection_.connection_migration_use_new_cid()) { EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)).Times(2); EXPECT_CALL(visitor_, SendNewConnectionId(_)).Times(2); } EXPECT_TRUE(connection_.OnRetireConnectionIdFrame(frame)); - if (!GetQuicReloadableFlag(quic_use_connection_id_on_default_path)) { - ASSERT_TRUE(retire_self_issued_cid_alarm->IsSet()); - // cid0 is retired when the retire CID alarm fires. - if (!GetQuicReloadableFlag(quic_use_connection_id_on_default_path)) - EXPECT_CALL(visitor_, OnServerConnectionIdRetired(cid0)); - retire_self_issued_cid_alarm->Fire(); - } } TEST_P(QuicConnectionTest, ServerRetireSelfIssuedConnectionId) { - if (!version().HasIetfQuicFrames() || - GetQuicReloadableFlag(quic_use_connection_id_on_default_path)) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!connection_.connection_migration_use_new_cid()) { return; } - QuicConnectionPeer::EnableMultipleConnectionIdSupport(&connection_); set_perspective(Perspective::IS_SERVER); connection_.CreateConnectionIdManager(); QuicConnectionId recorded_cid; @@ -14533,7 +14958,10 @@ TEST_P(QuicConnectionTest, ServerRetireSelfIssuedConnectionId) { QuicConnectionId cid0 = connection_id_; QuicConnectionId cid1; QuicConnectionId cid2; + EXPECT_EQ(connection_.connection_id(), cid0); + EXPECT_EQ(connection_.GetOneActiveServerConnectionId(), cid0); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) .WillOnce(Invoke(cid_recorder)); EXPECT_CALL(visitor_, SendNewConnectionId(_)); @@ -14544,29 +14972,67 @@ TEST_P(QuicConnectionTest, ServerRetireSelfIssuedConnectionId) { connection_.GetRetireSelfIssuedConnectionIdAlarm(); ASSERT_FALSE(retire_self_issued_cid_alarm->IsSet()); - QuicRetireConnectionIdFrame frame; - frame.sequence_number = 0u; + // Generate three packets with different connection IDs that will arrive out + // of order (2, 1, 3) later. + char buffers[3][kMaxOutgoingPacketSize]; + // Destination connection ID of packet1 is cid0. + auto packet1 = + ConstructPacket({QuicFrame(QuicPingFrame())}, ENCRYPTION_FORWARD_SECURE, + buffers[0], kMaxOutgoingPacketSize); + peer_creator_.SetServerConnectionId(cid1); + auto retire_cid_frame = std::make_unique(); + retire_cid_frame->sequence_number = 0u; + // Destination connection ID of packet2 is cid1. + auto packet2 = ConstructPacket({QuicFrame(retire_cid_frame.release())}, + ENCRYPTION_FORWARD_SECURE, buffers[1], + kMaxOutgoingPacketSize); + // Destination connection ID of packet3 is cid1. + auto packet3 = + ConstructPacket({QuicFrame(QuicPingFrame())}, ENCRYPTION_FORWARD_SECURE, + buffers[2], kMaxOutgoingPacketSize); + + // Packet2 with RetireConnectionId frame trigers sending NewConnectionId + // immediately. EXPECT_CALL(visitor_, OnServerConnectionIdIssued(_)) .WillOnce(Invoke(cid_recorder)); - // RetireConnectionId trigers sending NewConnectionId immediately. EXPECT_CALL(visitor_, SendNewConnectionId(_)); - EXPECT_TRUE(connection_.OnRetireConnectionIdFrame(frame)); + peer_creator_.SetServerConnectionId(cid1); + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *packet2); cid2 = recorded_cid; // cid0 is not retired immediately. EXPECT_THAT(connection_.GetActiveServerConnectionIds(), ElementsAre(cid0, cid1, cid2)); ASSERT_TRUE(retire_self_issued_cid_alarm->IsSet()); + EXPECT_EQ(connection_.connection_id(), cid1); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid0 || + connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); + + // Packet1 updates the connection ID on the default path but not the active + // connection ID. + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *packet1); + EXPECT_EQ(connection_.connection_id(), cid0); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid0 || + connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); + // cid0 is retired when the retire CID alarm fires. EXPECT_CALL(visitor_, OnServerConnectionIdRetired(cid0)); retire_self_issued_cid_alarm->Fire(); EXPECT_THAT(connection_.GetActiveServerConnectionIds(), ElementsAre(cid1, cid2)); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); + + // Packet3 updates the connection ID on the default path. + connection_.ProcessUdpPacket(kSelfAddress, kPeerAddress, *packet3); + EXPECT_EQ(connection_.connection_id(), cid1); + EXPECT_TRUE(connection_.GetOneActiveServerConnectionId() == cid1 || + connection_.GetOneActiveServerConnectionId() == cid2); } TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoAlternativePath) { - if (!version().HasIetfQuicFrames() || - !connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!version().HasIetfQuicFrames()) { return; } set_perspective(Perspective::IS_SERVER); @@ -14582,7 +15048,7 @@ TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoAlternativePath) { alternative_path->peer_address = QuicSocketAddress(new_host, 12345); alternative_path->server_connection_id = TestConnectionId(3); ASSERT_TRUE(alternative_path->client_connection_id.IsEmpty()); - ASSERT_FALSE(alternative_path->stateless_reset_token_received); + ASSERT_FALSE(alternative_path->stateless_reset_token.has_value()); QuicNewConnectionIdFrame frame; frame.sequence_number = 1u; @@ -14597,13 +15063,10 @@ TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoAlternativePath) { ASSERT_EQ(alternative_path->client_connection_id, frame.connection_id); ASSERT_EQ(alternative_path->stateless_reset_token, frame.stateless_reset_token); - ASSERT_TRUE(alternative_path->stateless_reset_token_received); } TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoDefaultPath) { - if (!version().HasIetfQuicFrames() || - !connection_.support_multiple_connection_ids() || - !connection_.use_connection_id_on_default_path()) { + if (!version().HasIetfQuicFrames()) { return; } set_perspective(Perspective::IS_SERVER); @@ -14626,7 +15089,7 @@ TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoDefaultPath) { ASSERT_FALSE(default_path->validated); ASSERT_TRUE(default_path->client_connection_id.IsEmpty()); - ASSERT_FALSE(default_path->stateless_reset_token_received); + ASSERT_FALSE(default_path->stateless_reset_token.has_value()); QuicNewConnectionIdFrame frame; frame.sequence_number = 1u; @@ -14640,13 +15103,11 @@ TEST_P(QuicConnectionTest, PatchMissingClientConnectionIdOntoDefaultPath) { ASSERT_EQ(default_path->client_connection_id, frame.connection_id); ASSERT_EQ(default_path->stateless_reset_token, frame.stateless_reset_token); - ASSERT_TRUE(default_path->stateless_reset_token_received); ASSERT_EQ(packet_creator->GetDestinationConnectionId(), frame.connection_id); } TEST_P(QuicConnectionTest, ShouldGeneratePacketBlockedByMissingConnectionId) { - if (!version().HasIetfQuicFrames() || - !connection_.support_multiple_connection_ids()) { + if (!version().HasIetfQuicFrames()) { return; } set_perspective(Perspective::IS_SERVER); @@ -14690,6 +15151,9 @@ TEST_P(QuicConnectionTest, LostDataThenGetAcknowledged) { QuicConnectionPeer::SetAddressValidated(&connection_); } connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Discard INITIAL key. + connection_.RemoveEncrypter(ENCRYPTION_INITIAL); + connection_.NeuterUnencryptedPackets(); EXPECT_CALL(visitor_, GetHandshakeState()) .WillRepeatedly(Return(HANDSHAKE_CONFIRMED)); @@ -14718,23 +15182,13 @@ TEST_P(QuicConnectionTest, LostDataThenGetAcknowledged) { InvokeWithoutArgs(¬ifier_, &SimpleSessionNotifier::OnCanWrite)); QuicIpAddress ip_address; ASSERT_TRUE(ip_address.FromString("127.0.52.223")); - if (GetQuicReloadableFlag(quic_donot_write_mid_packet_processing)) { - EXPECT_QUIC_BUG( - ProcessFramesPacketWithAddresses(frames, kSelfAddress, - QuicSocketAddress(ip_address, 1000), - ENCRYPTION_FORWARD_SECURE), - "Try to write mid packet processing"); - EXPECT_EQ(1u, writer_->path_challenge_frames().size()); - // Verify stream frame will not be retransmitted. - EXPECT_TRUE(writer_->stream_frames().empty()); - } else { - ProcessFramesPacketWithAddresses(frames, kSelfAddress, - QuicSocketAddress(ip_address, 1000), - ENCRYPTION_FORWARD_SECURE); - // In prod, this would cause FAILED_TO_SERIALIZE_PACKET since the stream - // data has been freed, but simple_data_producer does not free data. - EXPECT_EQ(1u, writer_->stream_frames().size()); - } + EXPECT_QUIC_BUG(ProcessFramesPacketWithAddresses( + frames, kSelfAddress, QuicSocketAddress(ip_address, 1000), + ENCRYPTION_FORWARD_SECURE), + "Try to write mid packet processing"); + EXPECT_EQ(1u, writer_->path_challenge_frames().size()); + // Verify stream frame will not be retransmitted. + EXPECT_TRUE(writer_->stream_frames().empty()); } TEST_P(QuicConnectionTest, PtoSendStreamData) { @@ -14775,13 +15229,409 @@ TEST_P(QuicConnectionTest, PtoSendStreamData) { ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); connection_.GetRetransmissionAlarm()->Fire(); - if (GetQuicReloadableFlag(quic_preempt_stream_data_with_handshake_packet) && - GetQuicReloadableFlag(quic_donot_pto_half_rtt_data)) { - // Verify INITIAL and HANDSHAKE get retransmitted. - EXPECT_EQ(0x02020202u, writer_->final_bytes_of_last_packet()); + // Verify INITIAL and HANDSHAKE get retransmitted. + EXPECT_EQ(0x02020202u, writer_->final_bytes_of_last_packet()); +} + +TEST_P(QuicConnectionTest, SendingZeroRttPacketsDoesNotPostponePTO) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + use_tagging_decrypter(); + connection_.SetEncrypter(ENCRYPTION_INITIAL, + std::make_unique(0x01)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send CHLO. + connection_.SendCryptoStreamData(); + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + // Install 0-RTT keys. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + + // CHLO gets acknowledged after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + QuicAckFrame frame1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // Send 0-RTT packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + if (GetQuicReloadableFlag( + quic_donot_rearm_pto_on_application_data_during_handshake)) { + // PTO deadline should be unchanged. + EXPECT_EQ(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); + } else { + // PTO gets re-armed. + EXPECT_NE(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); + } +} + +TEST_P(QuicConnectionTest, QueueingUndecryptablePacketsDoesntPostponePTO) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(3); + connection_.SetFromConfig(config); + use_tagging_decrypter(); + connection_.SetEncrypter(ENCRYPTION_INITIAL, + std::make_unique(0x01)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.RemoveDecrypter(ENCRYPTION_FORWARD_SECURE); + // Send CHLO. + connection_.SendCryptoStreamData(); + + // Send 0-RTT packet. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + + // CHLO gets acknowledged after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + QuicAckFrame frame1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // Receive an undecryptable packets. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + peer_framer_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0xFF)); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + // Verify PTO deadline is sooner. + EXPECT_GT(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); + pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // PTO fires. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + clock_.AdvanceTime(pto_deadline - clock_.ApproximateNow()); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // Verify PTO deadline does not change. + ProcessDataPacketAtLevel(4, !kHasStopWaiting, ENCRYPTION_FORWARD_SECURE); + EXPECT_EQ(pto_deadline, connection_.GetRetransmissionAlarm()->deadline()); +} + +TEST_P(QuicConnectionTest, QueueUndecryptableHandshakePackets) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + QuicConfig config; + config.set_max_undecryptable_packets(3); + connection_.SetFromConfig(config); + use_tagging_decrypter(); + connection_.SetEncrypter(ENCRYPTION_INITIAL, + std::make_unique(0x01)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.RemoveDecrypter(ENCRYPTION_HANDSHAKE); + // Send CHLO. + connection_.SendCryptoStreamData(); + + // Send 0-RTT packet. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + EXPECT_EQ(0u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); + + // Receive an undecryptable handshake packet. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + peer_framer_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0xFF)); + ProcessDataPacketAtLevel(3, !kHasStopWaiting, ENCRYPTION_HANDSHAKE); + // Verify this handshake packet gets queued. + EXPECT_EQ(1u, QuicConnectionPeer::NumUndecryptablePackets(&connection_)); +} + +TEST_P(QuicConnectionTest, PingNotSentAt0RTTLevelWhenInitialAvailable) { + if (!connection_.SupportsMultiplePacketNumberSpaces()) { + return; + } + use_tagging_decrypter(); + connection_.SetEncrypter(ENCRYPTION_INITIAL, + std::make_unique(0x01)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + // Send CHLO. + connection_.SendCryptoStreamData(); + // Send 0-RTT packet. + connection_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_ZERO_RTT); + connection_.SendStreamDataWithString(2, "foo", 0, NO_FIN); + + // CHLO gets acknowledged after 10ms. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(10)); + QuicAckFrame frame1 = InitAckFrame(1); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _)); + ProcessFramePacketAtLevel(1, QuicFrame(&frame1), ENCRYPTION_INITIAL); + // Verify PTO is still armed since address validation is not finished yet. + ASSERT_TRUE(connection_.GetRetransmissionAlarm()->IsSet()); + QuicTime pto_deadline = connection_.GetRetransmissionAlarm()->deadline(); + + // PTO fires. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(1); + clock_.AdvanceTime(pto_deadline - clock_.ApproximateNow()); + connection_.GetRetransmissionAlarm()->Fire(); + // Verify the PING gets sent in ENCRYPTION_INITIAL. + EXPECT_EQ(0x01010101u, writer_->final_bytes_of_last_packet()); +} + +TEST_P(QuicConnectionTest, AckElicitingFrames) { + if (!GetQuicReloadableFlag( + quic_remove_connection_migration_connection_option)) { + QuicConfig config; + config.SetConnectionOptionsToSend({kRVCM}); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + connection_.SetFromConfig(config); + } + if (!version().HasIetfQuicFrames() || + !connection_.connection_migration_use_new_cid()) { + return; + } + EXPECT_CALL(visitor_, SendNewConnectionId(_)).Times(2); + EXPECT_CALL(visitor_, OnRstStream(_)); + EXPECT_CALL(visitor_, OnWindowUpdateFrame(_)); + EXPECT_CALL(visitor_, OnBlockedFrame(_)); + EXPECT_CALL(visitor_, OnHandshakeDoneReceived()); + EXPECT_CALL(visitor_, OnStreamFrame(_)); + EXPECT_CALL(*send_algorithm_, OnCongestionEvent(_, _, _, _, _)); + EXPECT_CALL(visitor_, OnMaxStreamsFrame(_)); + EXPECT_CALL(visitor_, OnStreamsBlockedFrame(_)); + EXPECT_CALL(visitor_, OnStopSendingFrame(_)); + EXPECT_CALL(visitor_, OnMessageReceived("")); + EXPECT_CALL(visitor_, OnNewTokenReceived("")); + + SetClientConnectionId(TestConnectionId(12)); + connection_.CreateConnectionIdManager(); + QuicConnectionPeer::GetSelfIssuedConnectionIdManager(&connection_) + ->MaybeSendNewConnectionIds(); + connection_.set_can_receive_ack_frequency_frame(); + + QuicAckFrame ack_frame = InitAckFrame(1); + QuicRstStreamFrame rst_stream_frame; + QuicWindowUpdateFrame window_update_frame; + QuicPathChallengeFrame path_challenge_frame; + QuicNewConnectionIdFrame new_connection_id_frame; + QuicRetireConnectionIdFrame retire_connection_id_frame; + retire_connection_id_frame.sequence_number = 1u; + QuicStopSendingFrame stop_sending_frame; + QuicPathResponseFrame path_response_frame; + QuicMessageFrame message_frame; + QuicNewTokenFrame new_token_frame; + QuicAckFrequencyFrame ack_frequency_frame; + QuicBlockedFrame blocked_frame; + size_t packet_number = 1; + + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + + for (uint8_t i = 0; i < NUM_FRAME_TYPES; ++i) { + QuicFrameType frame_type = static_cast(i); + bool skipped = false; + QuicFrame frame; + QuicFrames frames; + // Add some padding to fullfill the min size requirement of header + // protection. + frames.push_back(QuicFrame(QuicPaddingFrame(10))); + switch (frame_type) { + case PADDING_FRAME: + frame = QuicFrame(QuicPaddingFrame(10)); + break; + case MTU_DISCOVERY_FRAME: + frame = QuicFrame(QuicMtuDiscoveryFrame()); + break; + case PING_FRAME: + frame = QuicFrame(QuicPingFrame()); + break; + case MAX_STREAMS_FRAME: + frame = QuicFrame(QuicMaxStreamsFrame()); + break; + case STOP_WAITING_FRAME: + // Not supported. + skipped = true; + break; + case STREAMS_BLOCKED_FRAME: + frame = QuicFrame(QuicStreamsBlockedFrame()); + break; + case STREAM_FRAME: + frame = QuicFrame(QuicStreamFrame()); + break; + case HANDSHAKE_DONE_FRAME: + frame = QuicFrame(QuicHandshakeDoneFrame()); + break; + case ACK_FRAME: + frame = QuicFrame(&ack_frame); + break; + case RST_STREAM_FRAME: + frame = QuicFrame(&rst_stream_frame); + break; + case CONNECTION_CLOSE_FRAME: + // Do not test connection close. + skipped = true; + break; + case GOAWAY_FRAME: + // Does not exist in IETF QUIC. + skipped = true; + break; + case BLOCKED_FRAME: + frame = QuicFrame(&blocked_frame); + break; + case WINDOW_UPDATE_FRAME: + frame = QuicFrame(&window_update_frame); + break; + case PATH_CHALLENGE_FRAME: + frame = QuicFrame(&path_challenge_frame); + break; + case STOP_SENDING_FRAME: + frame = QuicFrame(&stop_sending_frame); + break; + case NEW_CONNECTION_ID_FRAME: + frame = QuicFrame(&new_connection_id_frame); + break; + case RETIRE_CONNECTION_ID_FRAME: + frame = QuicFrame(&retire_connection_id_frame); + break; + case PATH_RESPONSE_FRAME: + frame = QuicFrame(&path_response_frame); + break; + case MESSAGE_FRAME: + frame = QuicFrame(&message_frame); + break; + case CRYPTO_FRAME: + // CRYPTO_FRAME is ack eliciting is covered by other tests. + skipped = true; + break; + case NEW_TOKEN_FRAME: + frame = QuicFrame(&new_token_frame); + break; + case ACK_FREQUENCY_FRAME: + frame = QuicFrame(&ack_frequency_frame); + break; + case NUM_FRAME_TYPES: + skipped = true; + break; + } + if (skipped) { + continue; + } + ASSERT_EQ(frame_type, frame.type); + frames.push_back(frame); + EXPECT_FALSE(connection_.HasPendingAcks()); + // Process frame. + ProcessFramesPacketAtLevel(packet_number++, frames, + ENCRYPTION_FORWARD_SECURE); + if (QuicUtils::IsAckElicitingFrame(frame_type)) { + ASSERT_TRUE(connection_.HasPendingAcks()) << frame; + // Flush ACK. + clock_.AdvanceTime(DefaultDelayedAckTime()); + connection_.GetAckAlarm()->Fire(); + } + EXPECT_FALSE(connection_.HasPendingAcks()); + ASSERT_TRUE(connection_.connected()); + } +} + +// Regression test for b/201643321. +TEST_P(QuicConnectionTest, FailedToRetransmitShlo) { + if (!version().HasIetfQuicFrames()) { + return; + } + set_perspective(Perspective::IS_SERVER); + EXPECT_CALL(visitor_, OnCryptoFrame(_)).Times(AnyNumber()); + EXPECT_CALL(visitor_, OnStreamFrame(_)).Times(AnyNumber()); + use_tagging_decrypter(); + // Received INITIAL 1. + ProcessCryptoPacketAtLevel(1, ENCRYPTION_INITIAL); + EXPECT_TRUE(connection_.HasPendingAcks()); + + peer_framer_.SetEncrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + + connection_.SetEncrypter(ENCRYPTION_INITIAL, + std::make_unique(0x01)); + connection_.SetEncrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x03)); + SetDecrypter(ENCRYPTION_HANDSHAKE, + std::make_unique(0x02)); + SetDecrypter(ENCRYPTION_ZERO_RTT, + std::make_unique(0x02)); + connection_.SetEncrypter(ENCRYPTION_FORWARD_SECURE, + std::make_unique(0x04)); + // Received ENCRYPTION_ZERO_RTT 1. + ProcessDataPacketAtLevel(1, !kHasStopWaiting, ENCRYPTION_ZERO_RTT); + { + QuicConnection::ScopedPacketFlusher flusher(&connection_); + // Send INITIAL 1. + connection_.SetDefaultEncryptionLevel(ENCRYPTION_INITIAL); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_INITIAL); + // Send HANDSHAKE 2. + EXPECT_CALL(visitor_, OnHandshakePacketSent()).Times(1); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_HANDSHAKE); + connection_.SendCryptoDataWithString("foo", 0, ENCRYPTION_HANDSHAKE); + connection_.SetDefaultEncryptionLevel(ENCRYPTION_FORWARD_SECURE); + // Send half RTT data to exhaust amplification credit. + connection_.SendStreamDataWithString(0, std::string(100 * 1024, 'a'), 0, + NO_FIN); + } + // Received INITIAL 2. + ProcessCryptoPacketAtLevel(2, ENCRYPTION_INITIAL); + ASSERT_TRUE(connection_.HasPendingAcks()); + // Verify ACK delay is 1ms. + EXPECT_EQ(clock_.Now() + kAlarmGranularity, + connection_.GetAckAlarm()->deadline()); + if (!GetQuicReloadableFlag( + quic_donot_check_amplification_limit_with_pending_timer_credit)) { + // ACK is not sent because of amplification limit throttled. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(0); } else { - // Application data preempts handshake data when PTO fires. + // ACK is not throttled by amplification limit, and SHLO is bundled. Also + // HANDSHAKE packet gets coalesced. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(2); + } + // ACK alarm fires. + clock_.AdvanceTime(kAlarmGranularity); + connection_.GetAckAlarm()->Fire(); + if (GetQuicReloadableFlag( + quic_donot_check_amplification_limit_with_pending_timer_credit)) { + // Verify HANDSHAKE packet is coalesced with INITIAL ACK + SHLO. EXPECT_EQ(0x03030303u, writer_->final_bytes_of_last_packet()); + // Only the first packet in the coalesced packet has been processed, + // verify SHLO is bundled with INITIAL ACK. + EXPECT_EQ(1u, writer_->ack_frames().size()); + EXPECT_EQ(1u, writer_->crypto_frames().size()); + // Process the coalesced HANDSHAKE packet. + ASSERT_TRUE(writer_->coalesced_packet() != nullptr); + auto packet = writer_->coalesced_packet()->Clone(); + writer_->framer()->ProcessPacket(*packet); + EXPECT_EQ(0u, writer_->ack_frames().size()); + EXPECT_EQ(1u, writer_->crypto_frames().size()); + ASSERT_TRUE(writer_->coalesced_packet() == nullptr); + } + + // Received INITIAL 3. + EXPECT_CALL(*send_algorithm_, OnPacketSent(_, _, _, _, _)).Times(AnyNumber()); + ProcessCryptoPacketAtLevel(3, ENCRYPTION_INITIAL); + if (!GetQuicReloadableFlag( + quic_donot_check_amplification_limit_with_pending_timer_credit)) { + EXPECT_FALSE(connection_.HasPendingAcks()); + } else { + EXPECT_TRUE(connection_.HasPendingAcks()); } } diff --git a/gquiche/quic/core/quic_constants.h b/gquiche/quic/core/quic_constants.h index eb488496..5bdff886 100644 --- a/gquiche/quic/core/quic_constants.h +++ b/gquiche/quic/core/quic_constants.h @@ -28,7 +28,7 @@ const uint64_t kNumMicrosPerSecond = kNumMicrosPerMilli * kNumMillisPerSecond; // Default number of connections for N-connection emulation. const uint32_t kDefaultNumConnections = 2; // Default initial maximum size in bytes of a QUIC packet. -const QuicByteCount kDefaultMaxPacketSize = 1350; +const QuicByteCount kDefaultMaxPacketSize = 1250; // Default initial maximum size in bytes of a QUIC packet for servers. const QuicByteCount kDefaultServerMaxPacketSize = 1000; // Maximum transmission unit on Ethernet. @@ -295,10 +295,14 @@ QUIC_EXPORT_PRIVATE extern const char* const kEPIDGoogleFrontEnd; QUIC_EXPORT_PRIVATE extern const char* const kEPIDGoogleFrontEnd0; // HTTP/3 Datagrams. -enum : QuicDatagramFlowId { - kFirstDatagramFlowIdClient = 0, - kFirstDatagramFlowIdServer = 1, - kDatagramFlowIdIncrement = 2, +enum : QuicDatagramContextId { + kFirstDatagramContextIdClient = 0, + kFirstDatagramContextIdServer = 1, + kDatagramContextIdIncrement = 2, +}; + +enum : uint64_t { + kHttpDatagramStreamIdDivisor = 4, }; } // namespace quic diff --git a/gquiche/quic/core/quic_control_frame_manager.cc b/gquiche/quic/core/quic_control_frame_manager.cc index ed6e759f..543b19d7 100644 --- a/gquiche/quic/core/quic_control_frame_manager.cc +++ b/gquiche/quic/core/quic_control_frame_manager.cc @@ -17,7 +17,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { @@ -60,8 +59,7 @@ void QuicControlFrameManager::WriteOrBufferQuicFrame(QuicFrame frame) { } void QuicControlFrameManager::WriteOrBufferRstStream( - QuicStreamId id, - QuicRstStreamErrorCode error, + QuicStreamId id, QuicResetStreamError error, QuicStreamOffset bytes_written) { QUIC_DVLOG(1) << "Writing RST_STREAM_FRAME"; WriteOrBufferQuicFrame((QuicFrame(new QuicRstStreamFrame( @@ -69,8 +67,7 @@ void QuicControlFrameManager::WriteOrBufferRstStream( } void QuicControlFrameManager::WriteOrBufferGoAway( - QuicErrorCode error, - QuicStreamId last_good_stream_id, + QuicErrorCode error, QuicStreamId last_good_stream_id, const std::string& reason) { QUIC_DVLOG(1) << "Writing GOAWAY_FRAME"; WriteOrBufferQuicFrame(QuicFrame(new QuicGoAwayFrame( @@ -78,8 +75,7 @@ void QuicControlFrameManager::WriteOrBufferGoAway( } void QuicControlFrameManager::WriteOrBufferWindowUpdate( - QuicStreamId id, - QuicStreamOffset byte_offset) { + QuicStreamId id, QuicStreamOffset byte_offset) { QUIC_DVLOG(1) << "Writing WINDOW_UPDATE_FRAME"; WriteOrBufferQuicFrame(QuicFrame( new QuicWindowUpdateFrame(++last_control_frame_id_, id, byte_offset))); @@ -108,11 +104,10 @@ void QuicControlFrameManager::WriteOrBufferMaxStreams(QuicStreamCount count, } void QuicControlFrameManager::WriteOrBufferStopSending( - QuicRstStreamErrorCode code, - QuicStreamId stream_id) { + QuicResetStreamError error, QuicStreamId stream_id) { QUIC_DVLOG(1) << "Writing STOP_SENDING_FRAME"; WriteOrBufferQuicFrame(QuicFrame( - new QuicStopSendingFrame(++last_control_frame_id_, stream_id, code))); + new QuicStopSendingFrame(++last_control_frame_id_, stream_id, error))); } void QuicControlFrameManager::WriteOrBufferHandshakeDone() { @@ -135,8 +130,7 @@ void QuicControlFrameManager::WriteOrBufferAckFrequency( } void QuicControlFrameManager::WriteOrBufferNewConnectionId( - const QuicConnectionId& connection_id, - uint64_t sequence_number, + const QuicConnectionId& connection_id, uint64_t sequence_number, uint64_t retire_prior_to, const StatelessResetToken& stateless_reset_token) { QUIC_DVLOG(1) << "Writing NEW_CONNECTION_ID frame"; @@ -167,14 +161,14 @@ void QuicControlFrameManager::OnControlFrameSent(const QuicFrame& frame) { } if (frame.type == WINDOW_UPDATE_FRAME) { QuicStreamId stream_id = frame.window_update_frame->stream_id; - if (QuicContainsKey(window_update_frames_, stream_id) && + if (window_update_frames_.contains(stream_id) && id > window_update_frames_[stream_id]) { // Consider the older window update of the same stream as acked. OnControlFrameIdAcked(window_update_frames_[stream_id]); } window_update_frames_[stream_id] = id; } - if (QuicContainsKey(pending_retransmissions_, id)) { + if (pending_retransmissions_.contains(id)) { // This is retransmitted control frame. pending_retransmissions_.erase(id); return; @@ -197,7 +191,7 @@ bool QuicControlFrameManager::OnControlFrameAcked(const QuicFrame& frame) { } if (frame.type == WINDOW_UPDATE_FRAME) { QuicStreamId stream_id = frame.window_update_frame->stream_id; - if (QuicContainsKey(window_update_frames_, stream_id) && + if (window_update_frames_.contains(stream_id) && window_update_frames_[stream_id] == id) { window_update_frames_.erase(stream_id); } @@ -223,7 +217,7 @@ void QuicControlFrameManager::OnControlFrameLost(const QuicFrame& frame) { // This frame has already been acked. return; } - if (!QuicContainsKey(pending_retransmissions_, id)) { + if (!pending_retransmissions_.contains(id)) { pending_retransmissions_[id] = true; QUIC_BUG_IF(quic_bug_12727_2, pending_retransmissions_.size() > control_frames_.size()) diff --git a/gquiche/quic/core/quic_control_frame_manager.h b/gquiche/quic/core/quic_control_frame_manager.h index 8ea7da51..519f16df 100644 --- a/gquiche/quic/core/quic_control_frame_manager.h +++ b/gquiche/quic/core/quic_control_frame_manager.h @@ -8,10 +8,13 @@ #include #include +#include "absl/container/flat_hash_map.h" #include "gquiche/quic/core/frames/quic_frame.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_connection_id.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_types.h" +#include "gquiche/common/quiche_circular_deque.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -51,8 +54,7 @@ class QUIC_EXPORT_PRIVATE QuicControlFrameManager { // Tries to send a WINDOW_UPDATE_FRAME. Buffers the frame if it cannot be sent // immediately. - void WriteOrBufferRstStream(QuicControlFrameId id, - QuicRstStreamErrorCode error, + void WriteOrBufferRstStream(QuicControlFrameId id, QuicResetStreamError error, QuicStreamOffset bytes_written); // Tries to send a GOAWAY_FRAME. Buffers the frame if it cannot be sent @@ -79,7 +81,7 @@ class QUIC_EXPORT_PRIVATE QuicControlFrameManager { // Tries to send an IETF-QUIC STOP_SENDING frame. The frame is buffered if it // can not be sent immediately. - void WriteOrBufferStopSending(QuicRstStreamErrorCode code, + void WriteOrBufferStopSending(QuicResetStreamError error, QuicStreamId stream_id); // Tries to send an HANDSHAKE_DONE frame. The frame is buffered if it can not @@ -94,8 +96,7 @@ class QUIC_EXPORT_PRIVATE QuicControlFrameManager { // Tries to send a NEW_CONNECTION_ID frame. The frame is buffered if it cannot // be sent immediately. void WriteOrBufferNewConnectionId( - const QuicConnectionId& connection_id, - uint64_t sequence_number, + const QuicConnectionId& connection_id, uint64_t sequence_number, uint64_t retire_prior_to, const StatelessResetToken& stateless_reset_token); @@ -163,7 +164,7 @@ class QUIC_EXPORT_PRIVATE QuicControlFrameManager { // frame. void WriteOrBufferQuicFrame(QuicFrame frame); - QuicCircularDeque control_frames_; + quiche::QuicheCircularDeque control_frames_; // Id of latest saved control frame. 0 if no control frame has been saved. QuicControlFrameId last_control_frame_id_; @@ -177,12 +178,13 @@ class QUIC_EXPORT_PRIVATE QuicControlFrameManager { // TODO(fayang): switch to linked_hash_set when chromium supports it. The bool // is not used here. // Lost control frames waiting to be retransmitted. - QuicLinkedHashMap pending_retransmissions_; + quiche::QuicheLinkedHashMap + pending_retransmissions_; DelegateInterface* delegate_; // Last sent window update frame for each stream. - QuicSmallMap window_update_frames_; + absl::flat_hash_map window_update_frames_; }; } // namespace quic diff --git a/gquiche/quic/core/quic_control_frame_manager_test.cc b/gquiche/quic/core/quic_control_frame_manager_test.cc index 649ca67e..4e31ae45 100644 --- a/gquiche/quic/core/quic_control_frame_manager_test.cc +++ b/gquiche/quic/core/quic_control_frame_manager_test.cc @@ -66,12 +66,16 @@ class QuicControlFrameManagerTest : public QuicTest { EXPECT_FALSE(manager_->WillingToWrite()); EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); - manager_->WriteOrBufferRstStream(kTestStreamId, QUIC_STREAM_CANCELLED, 0); + manager_->WriteOrBufferRstStream( + kTestStreamId, + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), 0); manager_->WriteOrBufferGoAway(QUIC_PEER_GOING_AWAY, kTestStreamId, "Going away."); manager_->WriteOrBufferWindowUpdate(kTestStreamId, 100); manager_->WriteOrBufferBlocked(kTestStreamId); - manager_->WriteOrBufferStopSending(kTestStopSendingCode, kTestStreamId); + manager_->WriteOrBufferStopSending( + QuicResetStreamError::FromInternal(kTestStopSendingCode), + kTestStreamId); number_of_frames_ = 5u; EXPECT_EQ(number_of_frames_, QuicControlFrameManagerPeer::QueueSize(manager_.get())); @@ -341,14 +345,18 @@ TEST_F(QuicControlFrameManagerTest, TooManyBufferedControlFrames) { // Write 995 control frames. EXPECT_CALL(*session_, WriteControlFrame(_, _)).WillOnce(Return(false)); for (size_t i = 0; i < 995; ++i) { - manager_->WriteOrBufferRstStream(kTestStreamId, QUIC_STREAM_CANCELLED, 0); + manager_->WriteOrBufferRstStream( + kTestStreamId, + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), 0); } // Verify write one more control frame causes connection close. EXPECT_CALL( *connection_, CloseConnection(QUIC_TOO_MANY_BUFFERED_CONTROL_FRAMES, _, ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET)); - manager_->WriteOrBufferRstStream(kTestStreamId, QUIC_STREAM_CANCELLED, 0); + manager_->WriteOrBufferRstStream( + kTestStreamId, QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), + 0); } } // namespace diff --git a/gquiche/quic/core/quic_crypto_client_handshaker.h b/gquiche/quic/core/quic_crypto_client_handshaker.h index 7e0fee97..359553bd 100644 --- a/gquiche/quic/core/quic_crypto_client_handshaker.h +++ b/gquiche/quic/core/quic_crypto_client_handshaker.h @@ -65,6 +65,13 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientHandshaker std::unique_ptr /*application_state*/) override { QUICHE_NOTREACHED(); } + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + QUICHE_NOTREACHED(); + return false; + } // From QuicCryptoHandshaker void OnHandshakeMessage(const CryptoHandshakeMessage& message) override; diff --git a/gquiche/quic/core/quic_crypto_client_handshaker_test.cc b/gquiche/quic/core/quic_crypto_client_handshaker_test.cc index c83a1974..a81d7be0 100644 --- a/gquiche/quic/core/quic_crypto_client_handshaker_test.cc +++ b/gquiche/quic/core/quic_crypto_client_handshaker_test.cc @@ -77,18 +77,20 @@ class DummyProofSource : public ProofSource { QuicTransportVersion /*transport_version*/, absl::string_view /*chlo_hash*/, std::unique_ptr callback) override { - QuicReferenceCountedPointer chain = - GetCertChain(server_address, client_address, hostname); + bool cert_matched_sni; + QuicReferenceCountedPointer chain = GetCertChain( + server_address, client_address, hostname, &cert_matched_sni); QuicCryptoProof proof; proof.signature = "Dummy signature"; proof.leaf_cert_scts = "Dummy timestamp"; + proof.cert_matched_sni = cert_matched_sni; callback->Run(true, chain, proof, /*details=*/nullptr); } QuicReferenceCountedPointer GetCertChain( const QuicSocketAddress& /*server_address*/, const QuicSocketAddress& /*client_address*/, - const std::string& /*hostname*/) override { + const std::string& /*hostname*/, bool* /*cert_matched_sni*/) override { std::vector certs; certs.push_back("Dummy cert"); return QuicReferenceCountedPointer( @@ -105,6 +107,11 @@ class DummyProofSource : public ProofSource { callback->Run(true, "Dummy signature", /*details=*/nullptr); } + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override { + return {}; + } + TicketCrypter* GetTicketCrypter() override { return nullptr; } }; diff --git a/gquiche/quic/core/quic_crypto_client_stream.cc b/gquiche/quic/core/quic_crypto_client_stream.cc index 18b67e76..ab04d076 100644 --- a/gquiche/quic/core/quic_crypto_client_stream.cc +++ b/gquiche/quic/core/quic_crypto_client_stream.cc @@ -43,11 +43,14 @@ QuicCryptoClientStream::QuicCryptoClientStream( server_id, this, session, std::move(verify_context), crypto_config, proof_handler); break; - case PROTOCOL_TLS1_3: - handshaker_ = std::make_unique( + case PROTOCOL_TLS1_3: { + auto handshaker = std::make_unique( server_id, this, session, std::move(verify_context), crypto_config, proof_handler, has_application_state); + tls_handshaker_ = handshaker.get(); + handshaker_ = std::move(handshaker); break; + } case PROTOCOL_UNSUPPORTED: QUIC_BUG(quic_bug_10296_1) << "Attempting to create QuicCryptoClientStream for unknown " @@ -125,6 +128,13 @@ QuicCryptoClientStream::CreateCurrentOneRttEncrypter() { return handshaker_->CreateCurrentOneRttEncrypter(); } +bool QuicCryptoClientStream::ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + return handshaker_->ExportKeyingMaterial(label, context, result_len, result); +} + std::string QuicCryptoClientStream::chlo_hash() const { return handshaker_->chlo_hash(); } @@ -150,21 +160,14 @@ void QuicCryptoClientStream::OnNewTokenReceived(absl::string_view token) { handshaker_->OnNewTokenReceived(token); } -std::string QuicCryptoClientStream::GetAddressToken() const { - QUICHE_DCHECK(false); - return ""; -} - -bool QuicCryptoClientStream::ValidateAddressToken( - absl::string_view /*token*/) const { - QUICHE_DCHECK(false); - return false; -} - void QuicCryptoClientStream::SetServerApplicationStateForResumption( std::unique_ptr application_state) { handshaker_->SetServerApplicationStateForResumption( std::move(application_state)); } +SSL* QuicCryptoClientStream::GetSsl() const { + return tls_handshaker_ == nullptr ? nullptr : tls_handshaker_->ssl(); +} + } // namespace quic diff --git a/gquiche/quic/core/quic_crypto_client_stream.h b/gquiche/quic/core/quic_crypto_client_stream.h index 81c40337..277bfe91 100644 --- a/gquiche/quic/core/quic_crypto_client_stream.h +++ b/gquiche/quic/core/quic_crypto_client_stream.h @@ -11,6 +11,7 @@ #include "gquiche/quic/core/crypto/proof_verifier.h" #include "gquiche/quic/core/crypto/quic_crypto_client_config.h" +#include "gquiche/quic/core/proto/cached_network_parameters_proto.h" #include "gquiche/quic/core/quic_config.h" #include "gquiche/quic/core/quic_crypto_handshaker.h" #include "gquiche/quic/core/quic_crypto_stream.h" @@ -25,6 +26,8 @@ namespace test { class QuicCryptoClientStreamPeer; } // namespace test +class TlsClientHandshaker; + class QUIC_EXPORT_PRIVATE QuicCryptoClientStreamBase : public QuicCryptoStream { public: explicit QuicCryptoClientStreamBase(QuicSession* session); @@ -63,6 +66,35 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientStreamBase : public QuicCryptoStream { // client. Does not count update messages that were received prior // to handshake confirmation. virtual int num_scup_messages_received() const = 0; + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + QUICHE_NOTREACHED(); + return false; + } + + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_params*/) const override { + QUICHE_DCHECK(false); + return ""; + } + + bool ValidateAddressToken(absl::string_view /*token*/) const override { + QUICHE_DCHECK(false); + return false; + } + + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + QUICHE_DCHECK(false); + return nullptr; + } + + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override { + QUICHE_DCHECK(false); + } }; class QUIC_EXPORT_PRIVATE QuicCryptoClientStream @@ -185,6 +217,13 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientStream // Called when application state is received. virtual void SetServerApplicationStateForResumption( std::unique_ptr application_state) = 0; + + // Called to obtain keying material export of length |result_len| with the + // given |label| and |context|. Returns false on failure. + virtual bool ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) = 0; }; // ProofHandler is an interface that handles callbacks from the crypto @@ -240,8 +279,6 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientStream ConnectionCloseSource source) override; void OnHandshakeDoneReceived() override; void OnNewTokenReceived(absl::string_view token) override; - std::string GetAddressToken() const override; - bool ValidateAddressToken(absl::string_view token) const override; HandshakeState GetHandshakeState() const override; void SetServerApplicationStateForResumption( std::unique_ptr application_state) override; @@ -250,7 +287,9 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientStream std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() override; std::unique_ptr CreateCurrentOneRttEncrypter() override; - + SSL* GetSsl() const override; + bool ExportKeyingMaterial(absl::string_view label, absl::string_view context, + size_t result_len, std::string* result) override; std::string chlo_hash() const; protected: @@ -261,6 +300,10 @@ class QUIC_EXPORT_PRIVATE QuicCryptoClientStream private: friend class test::QuicCryptoClientStreamPeer; std::unique_ptr handshaker_; + // Points to |handshaker_| if it uses TLS1.3. Otherwise, nullptr. + // TODO(danzh) change the type of |handshaker_| to TlsClientHandshaker after + // deprecating Google QUIC. + TlsClientHandshaker* tls_handshaker_{nullptr}; }; } // namespace quic diff --git a/gquiche/quic/core/quic_crypto_handshaker.cc b/gquiche/quic/core/quic_crypto_handshaker.cc index 664d4410..4f8afa07 100644 --- a/gquiche/quic/core/quic_crypto_handshaker.cc +++ b/gquiche/quic/core/quic_crypto_handshaker.cc @@ -27,10 +27,7 @@ void QuicCryptoHandshaker::SendHandshakeMessage( session()->OnCryptoHandshakeMessageSent(message); last_sent_handshake_message_tag_ = message.tag(); const QuicData& data = message.GetSerialized(); - stream_->WriteCryptoData(session_->use_write_or_buffer_data_at_level() - ? level - : session_->connection()->encryption_level(), - data.AsStringPiece()); + stream_->WriteCryptoData(level, data.AsStringPiece()); } void QuicCryptoHandshaker::OnError(CryptoFramer* framer) { diff --git a/gquiche/quic/core/quic_crypto_server_stream.cc b/gquiche/quic/core/quic_crypto_server_stream.cc index 1911032a..d1a00d26 100644 --- a/gquiche/quic/core/quic_crypto_server_stream.cc +++ b/gquiche/quic/core/quic_crypto_server_stream.cc @@ -11,7 +11,8 @@ #include "absl/strings/string_view.h" #include "openssl/sha.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/quic/platform/api/quic_testvalue.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -164,6 +165,21 @@ void QuicCryptoServerStream:: process_client_hello_cb_ = nullptr; proof_source_details_ = std::move(proof_source_details); + AdjustTestValue("quic::QuicCryptoServerStream::after_process_client_hello", + session()); + + if (noop_if_disconnected_after_process_chlo_) { + QUIC_RELOADABLE_FLAG_COUNT( + quic_crypto_noop_if_disconnected_after_process_chlo); + if (!session()->connection()->connected()) { + QUIC_CODE_COUNT(quic_crypto_disconnected_after_process_client_hello); + QUIC_LOG_FIRST_N(INFO, 10) + << "After processing CHLO, QUIC connection has been closed with code " + << session()->error() << ", details: " << session()->error_details(); + return; + } + } + const CryptoHandshakeMessage& message = result.client_hello; if (error != QUIC_NO_ERROR) { OnUnrecoverableError(error, error_details); @@ -287,15 +303,8 @@ void QuicCryptoServerStream::FinishSendServerConfigUpdate( QUIC_DVLOG(1) << "Server: Sending server config update: " << message.DebugString(); - if (!session()->use_write_or_buffer_data_at_level() && - !QuicVersionUsesCryptoFrames(transport_version())) { - const QuicData& data = message.GetSerialized(); - WriteOrBufferData(absl::string_view(data.data(), data.length()), false, - nullptr); - } else { - // Send server config update in ENCRYPTION_FORWARD_SECURE. - SendHandshakeMessage(message, ENCRYPTION_FORWARD_SECURE); - } + // Send server config update in ENCRYPTION_FORWARD_SECURE. + SendHandshakeMessage(message, ENCRYPTION_FORWARD_SECURE); ++num_server_config_update_messages_sent_; } @@ -344,7 +353,8 @@ void QuicCryptoServerStream::OnNewTokenReceived(absl::string_view /*token*/) { QUICHE_DCHECK(false); } -std::string QuicCryptoServerStream::GetAddressToken() const { +std::string QuicCryptoServerStream::GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) const { QUICHE_DCHECK(false); return ""; } @@ -359,6 +369,10 @@ bool QuicCryptoServerStream::ShouldSendExpectCTHeader() const { return signed_config_->proof.send_expect_ct_header; } +bool QuicCryptoServerStream::DidCertMatchSni() const { + return signed_config_->proof.cert_matched_sni; +} + const ProofSource::Details* QuicCryptoServerStream::ProofSourceDetails() const { return proof_source_details_.get(); } @@ -514,4 +528,6 @@ const QuicSocketAddress QuicCryptoServerStream::GetClientAddress() { return session()->connection()->peer_address(); } +SSL* QuicCryptoServerStream::GetSsl() const { return nullptr; } + } // namespace quic diff --git a/gquiche/quic/core/quic_crypto_server_stream.h b/gquiche/quic/core/quic_crypto_server_stream.h index 5c37cbb2..6ba2ddf1 100644 --- a/gquiche/quic/core/quic_crypto_server_stream.h +++ b/gquiche/quic/core/quic_crypto_server_stream.h @@ -48,9 +48,11 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerStream ConnectionCloseSource /*source*/) override {} void OnHandshakeDoneReceived() override; void OnNewTokenReceived(absl::string_view token) override; - std::string GetAddressToken() const override; + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_params*/) const override; bool ValidateAddressToken(absl::string_view token) const override; bool ShouldSendExpectCTHeader() const override; + bool DidCertMatchSni() const override; const ProofSource::Details* ProofSourceDetails() const override; // From QuicCryptoStream @@ -68,6 +70,7 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerStream std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter() override; std::unique_ptr CreateCurrentOneRttEncrypter() override; + SSL* GetSsl() const override; // From QuicCryptoHandshaker void OnHandshakeMessage(const CryptoHandshakeMessage& message) override; @@ -252,6 +255,8 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerStream bool encryption_established_; bool one_rtt_keys_available_; bool one_rtt_packet_decrypted_; + const bool noop_if_disconnected_after_process_chlo_ = GetQuicReloadableFlag( + quic_crypto_noop_if_disconnected_after_process_chlo); QuicReferenceCountedPointer crypto_negotiated_params_; }; diff --git a/gquiche/quic/core/quic_crypto_server_stream_base.h b/gquiche/quic/core/quic_crypto_server_stream_base.h index 4c1eccf0..7b7f0e45 100644 --- a/gquiche/quic/core/quic_crypto_server_stream_base.h +++ b/gquiche/quic/core/quic_crypto_server_stream_base.h @@ -73,11 +73,6 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerStreamBase : public QuicCryptoStream { // the resumption actually occurred. virtual bool ResumptionAttempted() const = 0; - virtual const CachedNetworkParameters* PreviousCachedNetworkParams() - const = 0; - virtual void SetPreviousCachedNetworkParams( - CachedNetworkParameters cached_network_params) = 0; - // NOTE: Indicating that the Expect-CT header should be sent here presents // a layering violation to some extent. The Expect-CT header only applies to // HTTP connections, while this class can be used for non-HTTP applications. @@ -85,11 +80,22 @@ class QUIC_EXPORT_PRIVATE QuicCryptoServerStreamBase : public QuicCryptoStream { // configuration for the certificate used in the connection is accessible. virtual bool ShouldSendExpectCTHeader() const = 0; + // Return true if a cert was picked that matched the SNI hostname. + virtual bool DidCertMatchSni() const = 0; + // Returns the Details from the latest call to ProofSource::GetProof or // ProofSource::ComputeTlsSignature. Returns nullptr if no such call has been // made. The Details are owned by the QuicCryptoServerStreamBase and the // pointer is only valid while the owning object is still valid. virtual const ProofSource::Details* ProofSourceDetails() const = 0; + + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + QUICHE_NOTREACHED(); + return false; + } }; // Creates an appropriate QuicCryptoServerStream for the provided parameters, diff --git a/gquiche/quic/core/quic_crypto_server_stream_test.cc b/gquiche/quic/core/quic_crypto_server_stream_test.cc index 5b88ad7e..4361db71 100644 --- a/gquiche/quic/core/quic_crypto_server_stream_test.cc +++ b/gquiche/quic/core/quic_crypto_server_stream_test.cc @@ -354,8 +354,7 @@ class QuicCryptoServerStreamTestWithFakeProofSource }; // Regression test for b/35422225, in which multiple CHLOs arriving on the same -// connection in close succession could cause a crash, especially when the use -// of Mentat signing meant that it took a while for each CHLO to be processed. +// connection in close succession could cause a crash. TEST_F(QuicCryptoServerStreamTestWithFakeProofSource, MultipleChlo) { Initialize(); GetFakeProofSource()->Activate(); diff --git a/gquiche/quic/core/quic_crypto_stream.cc b/gquiche/quic/core/quic_crypto_stream.cc index 08b866a5..02a81b73 100644 --- a/gquiche/quic/core/quic_crypto_stream.cc +++ b/gquiche/quic/core/quic_crypto_stream.cc @@ -127,32 +127,11 @@ void QuicCryptoStream::OnDataAvailableInSequencer( } } -bool QuicCryptoStream::ExportKeyingMaterial(absl::string_view label, - absl::string_view context, - size_t result_len, - std::string* result) const { - if (!one_rtt_keys_available()) { - QUIC_DLOG(ERROR) << "ExportKeyingMaterial was called before forward-secure" - << "encryption was established."; - return false; - } - return CryptoUtils::ExportKeyingMaterial( - crypto_negotiated_params().subkey_secret, label, context, result_len, - result); -} - void QuicCryptoStream::WriteCryptoData(EncryptionLevel level, absl::string_view data) { if (!QuicVersionUsesCryptoFrames(session()->transport_version())) { - if (session()->use_write_or_buffer_data_at_level()) { - WriteOrBufferDataAtLevel(data, /*fin=*/false, level, - /*ack_listener=*/nullptr); - return; - } - // The QUIC crypto handshake takes care of setting the appropriate - // encryption level before writing data. Since that is the only handshake - // supported in versions less than 47, |level| can be ignored here. - WriteOrBufferData(data, /* fin */ false, /* ack_listener */ nullptr); + WriteOrBufferDataAtLevel(data, /*fin=*/false, level, + /*ack_listener=*/nullptr); return; } if (data.empty()) { diff --git a/gquiche/quic/core/quic_crypto_stream.h b/gquiche/quic/core/quic_crypto_stream.h index 16aba19b..04e3b821 100644 --- a/gquiche/quic/core/quic_crypto_stream.h +++ b/gquiche/quic/core/quic_crypto_stream.h @@ -13,6 +13,7 @@ #include "openssl/ssl.h" #include "gquiche/quic/core/crypto/crypto_framer.h" #include "gquiche/quic/core/crypto/crypto_utils.h" +#include "gquiche/quic/core/proto/cached_network_parameters_proto.h" #include "gquiche/quic/core/quic_config.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_stream.h" @@ -21,6 +22,7 @@ namespace quic { +class CachedNetworkParameters; class QuicSession; // Crypto handshake messages in QUIC take place over a reserved stream with the @@ -63,11 +65,12 @@ class QUIC_EXPORT_PRIVATE QuicCryptoStream : public QuicStream { // Performs key extraction to derive a new secret of |result_len| bytes // dependent on |label|, |context|, and the stream's negotiated subkey secret. // Returns false if the handshake has not been confirmed or the parameters are - // invalid (e.g. |label| contains null bytes); returns true on success. - bool ExportKeyingMaterial(absl::string_view label, - absl::string_view context, - size_t result_len, - std::string* result) const; + // invalid (e.g. |label| contains null bytes); returns true on success. This + // method is only supported for IETF QUIC and MUST NOT be called in gQUIC as + // that'll trigger an assert in DEBUG build. + virtual bool ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, std::string* result) = 0; // Writes |data| to the QuicStream at level |level|. virtual void WriteCryptoData(EncryptionLevel level, absl::string_view data); @@ -108,11 +111,22 @@ class QUIC_EXPORT_PRIVATE QuicCryptoStream : public QuicStream { virtual void OnNewTokenReceived(absl::string_view token) = 0; // Called to get an address token. - virtual std::string GetAddressToken() const = 0; + virtual std::string GetAddressToken( + const CachedNetworkParameters* cached_network_params) const = 0; // Called to validate |token|. virtual bool ValidateAddressToken(absl::string_view token) const = 0; + // Get the last CachedNetworkParameters received from a valid address token. + virtual const CachedNetworkParameters* PreviousCachedNetworkParams() + const = 0; + + // Set the CachedNetworkParameters that will be returned by + // PreviousCachedNetworkParams. + // TODO(wub): This function is test only, move it to a test only library. + virtual void SetPreviousCachedNetworkParams( + CachedNetworkParameters cached_network_params) = 0; + // Returns current handshake state. virtual HandshakeState GetHandshakeState() const = 0; @@ -147,6 +161,11 @@ class QUIC_EXPORT_PRIVATE QuicCryptoStream : public QuicStream { // decrypter returned by AdvanceKeysAndCreateCurrentOneRttDecrypter(). virtual std::unique_ptr CreateCurrentOneRttEncrypter() = 0; + // Return the SSL struct object created by BoringSSL if the stream is using + // TLS1.3. Otherwise, return nullptr. + // This method is used in Envoy. + virtual SSL* GetSsl() const = 0; + // Called to cancel retransmission of unencrypted crypto stream data. void NeuterUnencryptedStreamData(); @@ -220,6 +239,12 @@ class QUIC_EXPORT_PRIVATE QuicCryptoStream : public QuicStream { // data, and false if all data has been acked. bool IsWaitingForAcks() const; + // Helper method for OnDataAvailable. Calls CryptoMessageParser::ProcessInput + // with the data available in |sequencer| and |level|, and marks the data + // passed to ProcessInput as consumed. + virtual void OnDataAvailableInSequencer(QuicStreamSequencer* sequencer, + EncryptionLevel level); + private: // Data sent and received in CRYPTO frames is sent at multiple encryption // levels. Some of the state for the single logical crypto stream is split @@ -232,12 +257,6 @@ class QUIC_EXPORT_PRIVATE QuicCryptoStream : public QuicStream { QuicStreamSendBuffer send_buffer; }; - // Helper method for OnDataAvailable. Calls CryptoMessageParser::ProcessInput - // with the data available in |sequencer| and |level|, and marks the data - // passed to ProcessInput as consumed. - void OnDataAvailableInSequencer(QuicStreamSequencer* sequencer, - EncryptionLevel level); - // Consumed data according to encryption levels. // TODO(fayang): This is not needed once switching from QUIC crypto to // TLS 1.3, which never encrypts crypto data. diff --git a/gquiche/quic/core/quic_crypto_stream_test.cc b/gquiche/quic/core/quic_crypto_stream_test.cc index d6749614..16e52b8c 100644 --- a/gquiche/quic/core/quic_crypto_stream_test.cc +++ b/gquiche/quic/core/quic_crypto_stream_test.cc @@ -64,10 +64,19 @@ class MockQuicCryptoStream : public QuicCryptoStream, void OnHandshakePacketSent() override {} void OnHandshakeDoneReceived() override {} void OnNewTokenReceived(absl::string_view /*token*/) override {} - std::string GetAddressToken() const override { return ""; } + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } bool ValidateAddressToken(absl::string_view /*token*/) const override { return true; } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; } void SetServerApplicationStateForResumption( std::unique_ptr /*application_state*/) override {} @@ -79,6 +88,13 @@ class MockQuicCryptoStream : public QuicCryptoStream, std::unique_ptr CreateCurrentOneRttEncrypter() override { return nullptr; } + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + SSL* GetSsl() const override { return nullptr; } private: QuicReferenceCountedPointer params_; @@ -689,6 +705,21 @@ TEST_F(QuicCryptoStreamTest, RetransmitCryptoFramesAndPartialWrite) { EXPECT_FALSE(stream_->HasPendingCryptoRetransmission()); } +// Regression test for b/203199510 +TEST_F(QuicCryptoStreamTest, EmptyCryptoFrame) { + if (!QuicVersionUsesCryptoFrames(connection_->transport_version())) { + return; + } + if (GetQuicReloadableFlag(quic_accept_empty_crypto_frame)) { + EXPECT_CALL(*connection_, CloseConnection(_, _, _)).Times(0); + } else { + EXPECT_CALL(*connection_, + CloseConnection(QUIC_EMPTY_STREAM_FRAME_NO_FIN, _, _)); + } + QuicCryptoFrame empty_crypto_frame(ENCRYPTION_INITIAL, 0, nullptr, 0); + stream_->OnCryptoFrame(empty_crypto_frame); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/quic_datagram_queue.cc b/gquiche/quic/core/quic_datagram_queue.cc index 4abf28fc..410fbc66 100644 --- a/gquiche/quic/core/quic_datagram_queue.cc +++ b/gquiche/quic/core/quic_datagram_queue.cc @@ -4,11 +4,11 @@ #include "gquiche/quic/core/quic_datagram_queue.h" +#include "absl/types/span.h" #include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/core/quic_types.h" -#include "gquiche/quic/platform/api/quic_mem_slice_span.h" namespace quic { @@ -30,8 +30,7 @@ MessageStatus QuicDatagramQueue::SendOrQueueDatagram(QuicMemSlice datagram) { // the datagrams are sent in the same order that they were sent by the // application. if (queue_.empty()) { - QuicMemSliceSpan span(&datagram); - MessageResult result = session_->SendMessage(span); + MessageResult result = session_->SendMessage(absl::MakeSpan(&datagram, 1)); if (result.status != MESSAGE_STATUS_BLOCKED) { if (observer_) { observer_->OnDatagramProcessed(result.status); @@ -51,8 +50,8 @@ absl::optional QuicDatagramQueue::TrySendingNextDatagram() { return absl::nullopt; } - QuicMemSliceSpan span(&queue_.front().datagram); - MessageResult result = session_->SendMessage(span); + MessageResult result = + session_->SendMessage(absl::MakeSpan(&queue_.front().datagram, 1)); if (result.status != MESSAGE_STATUS_BLOCKED) { queue_.pop_front(); if (observer_) { diff --git a/gquiche/quic/core/quic_datagram_queue.h b/gquiche/quic/core/quic_datagram_queue.h index 172a08de..b887118b 100644 --- a/gquiche/quic/core/quic_datagram_queue.h +++ b/gquiche/quic/core/quic_datagram_queue.h @@ -8,10 +8,10 @@ #include #include "absl/types/optional.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_mem_slice.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -80,7 +80,7 @@ class QUIC_EXPORT_PRIVATE QuicDatagramQueue { const QuicClock* clock_; QuicTime::Delta max_time_in_queue_ = QuicTime::Delta::Zero(); - QuicCircularDeque queue_; + quiche::QuicheCircularDeque queue_; std::unique_ptr observer_; }; diff --git a/gquiche/quic/core/quic_datagram_queue_test.cc b/gquiche/quic/core/quic_datagram_queue_test.cc index 7ca59670..0fa526fa 100644 --- a/gquiche/quic/core/quic_datagram_queue_test.cc +++ b/gquiche/quic/core/quic_datagram_queue_test.cc @@ -176,8 +176,9 @@ TEST_F(QuicDatagramQueueTest, Expiry) { std::vector messages; EXPECT_CALL(*connection_, SendMessage(_, _, _)) .WillRepeatedly([&messages](QuicMessageId /*id*/, - QuicMemSliceSpan message, bool /*flush*/) { - messages.push_back(std::string(message.GetData(0))); + absl::Span message, + bool /*flush*/) { + messages.push_back(std::string(message[0].AsStringView())); return MESSAGE_STATUS_SUCCESS; }); EXPECT_EQ(2u, queue_.SendDatagrams()); diff --git a/gquiche/quic/core/quic_dispatcher.cc b/gquiche/quic/core/quic_dispatcher.cc index 8ed84a06..9c0af8c8 100644 --- a/gquiche/quic/core/quic_dispatcher.cc +++ b/gquiche/quic/core/quic_dispatcher.cc @@ -26,7 +26,7 @@ #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_stack_trace.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -40,7 +40,7 @@ namespace { const QuicPacketLength kMinClientInitialPacketLength = 1200; // An alarm that informs the QuicDispatcher to delete old sessions. -class DeleteSessionsAlarm : public QuicAlarm::Delegate { +class DeleteSessionsAlarm : public QuicAlarm::DelegateWithoutContext { public: explicit DeleteSessionsAlarm(QuicDispatcher* dispatcher) : dispatcher_(dispatcher) {} @@ -54,6 +54,24 @@ class DeleteSessionsAlarm : public QuicAlarm::Delegate { QuicDispatcher* dispatcher_; }; +// An alarm that informs the QuicDispatcher to clear +// recent_stateless_reset_addresses_. +class ClearStatelessResetAddressesAlarm + : public QuicAlarm::DelegateWithoutContext { + public: + explicit ClearStatelessResetAddressesAlarm(QuicDispatcher* dispatcher) + : dispatcher_(dispatcher) {} + ClearStatelessResetAddressesAlarm(const DeleteSessionsAlarm&) = delete; + ClearStatelessResetAddressesAlarm& operator=(const DeleteSessionsAlarm&) = + delete; + + void OnAlarm() override { dispatcher_->ClearStatelessResetAddresses(); } + + private: + // Not owned. + QuicDispatcher* dispatcher_; +}; + // Collects packets serialized by a QuicPacketCreator in order // to be handed off to the time wait list manager. class PacketCollector : public QuicPacketCreator::DelegateInterface, @@ -159,7 +177,6 @@ class StatelessConnectionTerminator { SerializeConnectionClosePacket(error_code, error_details); time_wait_list_manager_->AddConnectionIdToTimeWait( - server_connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, TimeWaitConnectionInfo(ietf_quic, collector_.packets(), std::move(active_connection_ids), @@ -205,6 +222,10 @@ class ChloAlpnSniExtractor : public ChloExtractor::Delegate { if (chlo.GetStringPiece(quic::kSNI, &sni)) { sni_ = std::string(sni); } + absl::string_view uaid_value; + if (chlo.GetStringPiece(quic::kUAID, &uaid_value)) { + uaid_ = std::string(uaid_value); + } if (version == LegacyVersionForEncapsulation().transport_version) { absl::string_view qlve_value; if (chlo.GetStringPiece(kQLVE, &qlve_value)) { @@ -217,6 +238,8 @@ class ChloAlpnSniExtractor : public ChloExtractor::Delegate { std::string&& ConsumeSni() { return std::move(sni_); } + std::string&& ConsumeUaid() { return std::move(uaid_); } + std::string&& ConsumeLegacyVersionEncapsulationInnerPacket() { return std::move(legacy_version_encapsulation_inner_packet_); } @@ -224,15 +247,14 @@ class ChloAlpnSniExtractor : public ChloExtractor::Delegate { private: std::string alpn_; std::string sni_; + std::string uaid_; std::string legacy_version_encapsulation_inner_packet_; }; bool MaybeHandleLegacyVersionEncapsulation( QuicDispatcher* dispatcher, - ChloAlpnSniExtractor* alpn_extractor, + std::string legacy_version_encapsulation_inner_packet, const ReceivedPacketInfo& packet_info) { - std::string legacy_version_encapsulation_inner_packet = - alpn_extractor->ConsumeLegacyVersionEncapsulationInnerPacket(); if (legacy_version_encapsulation_inner_packet.empty()) { // This CHLO did not contain the Legacy Version Encapsulation tag. return false; @@ -244,16 +266,15 @@ bool MaybeHandleLegacyVersionEncapsulation( QuicVersionLabel version_label; ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); QuicConnectionId destination_connection_id, source_connection_id; - bool retry_token_present; - absl::string_view retry_token; + absl::optional retry_token; std::string detailed_error; const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( QuicEncryptedPacket(legacy_version_encapsulation_inner_packet.data(), legacy_version_encapsulation_inner_packet.length()), kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_present, &has_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); if (error != QUIC_NO_ERROR) { QUIC_DLOG(ERROR) << "Failed to parse Legacy Version Encapsulation inner packet:" @@ -306,8 +327,7 @@ bool MaybeHandleLegacyVersionEncapsulation( } // namespace QuicDispatcher::QuicDispatcher( - const QuicConfig* config, - const QuicCryptoServerConfig* crypto_config, + const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, QuicVersionManager* version_manager, std::unique_ptr helper, std::unique_ptr session_helper, @@ -330,10 +350,9 @@ QuicDispatcher::QuicDispatcher( allow_short_initial_server_connection_ids_(false), expected_server_connection_id_length_( expected_server_connection_id_length), + clear_stateless_reset_addresses_alarm_(alarm_factory_->CreateAlarm( + new ClearStatelessResetAddressesAlarm(this))), should_update_expected_server_connection_id_length_(false) { - if (use_reference_counted_session_map_) { - QUIC_RESTART_FLAG_COUNT(quic_use_reference_counted_sesssion_map); - } QUIC_BUG_IF(quic_bug_12724_1, GetSupportedVersions().empty()) << "Trying to create dispatcher without any supported versions"; QUIC_DLOG(INFO) << "Created QuicDispatcher with versions: " @@ -341,16 +360,15 @@ QuicDispatcher::QuicDispatcher( } QuicDispatcher::~QuicDispatcher() { - if (use_reference_counted_session_map_) { - reference_counted_session_map_.clear(); - closed_ref_counted_session_list_.clear(); - if (support_multiple_cid_per_connection_) { - num_sessions_in_session_map_ = 0; - } - } else { - session_map_.clear(); - closed_session_list_.clear(); + if (delete_sessions_alarm_ != nullptr) { + delete_sessions_alarm_->PermanentCancel(); } + if (clear_stateless_reset_addresses_alarm_ != nullptr) { + clear_stateless_reset_addresses_alarm_->PermanentCancel(); + } + reference_counted_session_map_.clear(); + closed_ref_counted_session_list_.clear(); + num_sessions_in_session_map_ = 0; } void QuicDispatcher::InitializeWithWriter(QuicPacketWriter* writer) { @@ -368,14 +386,12 @@ void QuicDispatcher::ProcessPacket(const QuicSocketAddress& self_address, absl::string_view(packet.data(), packet.length())); ReceivedPacketInfo packet_info(self_address, peer_address, packet); std::string detailed_error; - bool retry_token_present; - absl::string_view retry_token; const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( packet, expected_server_connection_id_length_, &packet_info.form, &packet_info.long_packet_type, &packet_info.version_flag, &packet_info.use_length_prefix, &packet_info.version_label, &packet_info.version, &packet_info.destination_connection_id, - &packet_info.source_connection_id, &retry_token_present, &retry_token, + &packet_info.source_connection_id, &packet_info.retry_token, &detailed_error); if (error != QUIC_NO_ERROR) { // Packet has framing error. @@ -478,12 +494,49 @@ QuicConnectionId QuicDispatcher::ReplaceLongServerConnectionId( server_connection_id, expected_server_connection_id_length); } +namespace { +constexpr bool IsSourceUdpPortBlocked(uint16_t port) { + // These UDP source ports have been observed in large scale denial of service + // attacks and are not expected to ever carry user traffic, they are therefore + // blocked as a safety measure. See draft-ietf-quic-applicability for details. + constexpr uint16_t blocked_ports[] = { + 0, // We cannot send to port 0 so drop that source port. + 17, // Quote of the Day, can loop with QUIC. + 19, // Chargen, can loop with QUIC. + 53, // DNS, vulnerable to reflection attacks. + 111, // Portmap. + 123, // NTP, vulnerable to reflection attacks. + 137, // NETBIOS Name Service, + 128, // NETBIOS Datagram Service + 161, // SNMP. + 389, // CLDAP. + 500, // IKE, can loop with QUIC. + 1900, // SSDP, vulnerable to reflection attacks. + 5353, // mDNS, vulnerable to reflection attacks. + 11211, // memcache, vulnerable to reflection attacks. + // This list MUST be sorted in increasing order. + }; + constexpr size_t num_blocked_ports = ABSL_ARRAYSIZE(blocked_ports); + constexpr uint16_t highest_blocked_port = + blocked_ports[num_blocked_ports - 1]; + if (QUICHE_PREDICT_TRUE(port > highest_blocked_port)) { + // Early-return to skip comparisons for the majority of traffic. + return false; + } + for (size_t i = 0; i < num_blocked_ports; i++) { + if (port == blocked_ports[i]) { + return true; + } + } + return false; +} +} // namespace + bool QuicDispatcher::MaybeDispatchPacket( const ReceivedPacketInfo& packet_info) { - // Port zero is only allowed for unidirectional UDP, so is disallowed by QUIC. - // Given that we can't even send a reply rejecting the packet, just drop the - // packet. - if (packet_info.peer_address.port() == 0) { + if (IsSourceUdpPortBlocked(packet_info.peer_address.port())) { + // Silently drop the received packet. + QUIC_CODE_COUNT(quic_dropped_blocked_port); return true; } @@ -510,60 +563,44 @@ bool QuicDispatcher::MaybeDispatchPacket( return true; } + if (packet_info.version_flag && packet_info.version.IsKnown() && + !QuicUtils::IsConnectionIdLengthValidForVersion( + server_connection_id.length(), + packet_info.version.transport_version)) { + QUIC_DLOG(INFO) << "Packet with destination connection ID " + << server_connection_id << " is invalid with version " + << packet_info.version; + // Drop the packet silently. + QUIC_CODE_COUNT(quic_dropped_invalid_initial_connection_id); + return true; + } + // Packets with connection IDs for active connections are processed // immediately. - if (use_reference_counted_session_map_) { - auto it = reference_counted_session_map_.find(server_connection_id); - if (it != reference_counted_session_map_.end()) { - QUICHE_DCHECK( - !buffered_packets_.HasBufferedPackets(server_connection_id)); - if (packet_info.version_flag && - packet_info.version != it->second->version() && - packet_info.version == LegacyVersionForEncapsulation()) { - // This packet is using the Legacy Version Encapsulation version but the - // corresponding session isn't, attempt extraction of inner packet. - ChloAlpnSniExtractor alpn_extractor; - if (ChloExtractor::Extract(packet_info.packet, packet_info.version, - config_->create_session_tag_indicators(), - &alpn_extractor, - server_connection_id.length())) { - if (MaybeHandleLegacyVersionEncapsulation(this, &alpn_extractor, - packet_info)) { - return true; - } - } - } - it->second->ProcessUdpPacket(packet_info.self_address, - packet_info.peer_address, - packet_info.packet); - return true; - } - } else { - auto it = session_map_.find(server_connection_id); - if (it != session_map_.end()) { - QUICHE_DCHECK( - !buffered_packets_.HasBufferedPackets(server_connection_id)); - if (packet_info.version_flag && - packet_info.version != it->second->version() && - packet_info.version == LegacyVersionForEncapsulation()) { - // This packet is using the Legacy Version Encapsulation version but the - // corresponding session isn't, attempt extraction of inner packet. - ChloAlpnSniExtractor alpn_extractor; - if (ChloExtractor::Extract(packet_info.packet, packet_info.version, - config_->create_session_tag_indicators(), - &alpn_extractor, - server_connection_id.length())) { - if (MaybeHandleLegacyVersionEncapsulation(this, &alpn_extractor, - packet_info)) { - return true; - } + auto it = reference_counted_session_map_.find(server_connection_id); + if (it != reference_counted_session_map_.end()) { + QUICHE_DCHECK(!buffered_packets_.HasBufferedPackets(server_connection_id)); + if (packet_info.version_flag && + packet_info.version != it->second->version() && + packet_info.version == LegacyVersionForEncapsulation()) { + // This packet is using the Legacy Version Encapsulation version but the + // corresponding session isn't, attempt extraction of inner packet. + ChloAlpnSniExtractor alpn_extractor; + if (ChloExtractor::Extract(packet_info.packet, packet_info.version, + config_->create_session_tag_indicators(), + &alpn_extractor, + server_connection_id.length())) { + if (MaybeHandleLegacyVersionEncapsulation( + this, + alpn_extractor.ConsumeLegacyVersionEncapsulationInnerPacket(), + packet_info)) { + return true; } } - it->second->ProcessUdpPacket(packet_info.self_address, - packet_info.peer_address, - packet_info.packet); - return true; } + it->second->ProcessUdpPacket(packet_info.self_address, + packet_info.peer_address, packet_info.packet); + return true; } if (packet_info.version.IsKnown()) { // We did not find the connection ID, check if we've replaced it. @@ -575,28 +612,15 @@ bool QuicDispatcher::MaybeDispatchPacket( QuicConnectionId replaced_connection_id = MaybeReplaceServerConnectionId( server_connection_id, packet_info.version); if (replaced_connection_id != server_connection_id) { - if (use_reference_counted_session_map_) { - // Search for the replacement. - auto it2 = reference_counted_session_map_.find(replaced_connection_id); - if (it2 != reference_counted_session_map_.end()) { - QUICHE_DCHECK( - !buffered_packets_.HasBufferedPackets(replaced_connection_id)); - it2->second->ProcessUdpPacket(packet_info.self_address, - packet_info.peer_address, - packet_info.packet); - return true; - } - } else { - // Search for the replacement. - auto it2 = session_map_.find(replaced_connection_id); - if (it2 != session_map_.end()) { - QUICHE_DCHECK( - !buffered_packets_.HasBufferedPackets(replaced_connection_id)); - it2->second->ProcessUdpPacket(packet_info.self_address, - packet_info.peer_address, - packet_info.packet); - return true; - } + // Search for the replacement. + auto it2 = reference_counted_session_map_.find(replaced_connection_id); + if (it2 != reference_counted_session_map_.end()) { + QUICHE_DCHECK( + !buffered_packets_.HasBufferedPackets(replaced_connection_id)); + it2->second->ProcessUdpPacket(packet_info.self_address, + packet_info.peer_address, + packet_info.packet); + return true; } } } @@ -683,76 +707,38 @@ void QuicDispatcher::ProcessHeader(ReceivedPacketInfo* packet_info) { packet_info->destination_connection_id; // Packet's connection ID is unknown. Apply the validity checks. QuicPacketFate fate = ValidityChecks(*packet_info); - ChloAlpnSniExtractor alpn_extractor; - switch (fate) { - case kFateProcess: { - if (packet_info->version.handshake_protocol == PROTOCOL_TLS1_3) { - bool has_full_tls_chlo = false; - std::string sni; - std::vector alpns; - if (buffered_packets_.HasBufferedPackets( - packet_info->destination_connection_id)) { - // If we already have buffered packets for this connection ID, - // use the associated TlsChloExtractor to parse this packet. - has_full_tls_chlo = - buffered_packets_.IngestPacketForTlsChloExtraction( - packet_info->destination_connection_id, packet_info->version, - packet_info->packet, &alpns, &sni); - } else { - // If we do not have a BufferedPacketList for this connection ID, - // create a single-use one to check whether this packet contains a - // full single-packet CHLO. - TlsChloExtractor tls_chlo_extractor; - tls_chlo_extractor.IngestPacket(packet_info->version, - packet_info->packet); - if (tls_chlo_extractor.HasParsedFullChlo()) { - // This packet contains a full single-packet CHLO. - has_full_tls_chlo = true; - alpns = tls_chlo_extractor.alpns(); - sni = tls_chlo_extractor.server_name(); - } - } - if (has_full_tls_chlo) { - ProcessChlo(alpns, sni, packet_info); - } else { - // This packet does not contain a full CHLO. It could be a 0-RTT - // packet that arrived before the CHLO (due to loss or reordering), - // or it could be a fragment of a multi-packet CHLO. - BufferEarlyPacket(*packet_info); - } - break; - } - if (GetQuicFlag(FLAGS_quic_allow_chlo_buffering) && - !ChloExtractor::Extract(packet_info->packet, packet_info->version, - config_->create_session_tag_indicators(), - &alpn_extractor, - server_connection_id.length())) { - // Buffer non-CHLO packets. - BufferEarlyPacket(*packet_info); - break; - } - // We only apply this check for versions that do not use the IETF - // invariant header because those versions are already checked in - // QuicDispatcher::MaybeDispatchPacket. - if (packet_info->version_flag && - !packet_info->version.HasIetfInvariantHeader() && - crypto_config()->validate_chlo_size() && - packet_info->packet.length() < kMinClientInitialPacketLength) { - QUIC_DVLOG(1) << "Dropping CHLO packet which is too short, length: " - << packet_info->packet.length(); - QUIC_CODE_COUNT(quic_drop_small_chlo_packets); - break; - } + if (fate == kFateProcess) { + absl::optional parsed_chlo = + TryExtractChloOrBufferEarlyPacket(*packet_info); + if (!parsed_chlo.has_value()) { + // Client Hello incomplete. Packet has been buffered or (rarely) dropped. + return; + } - if (MaybeHandleLegacyVersionEncapsulation(this, &alpn_extractor, - *packet_info)) { - break; + // Client Hello fully received. + fate = ValidityChecksOnFullChlo(*packet_info, *parsed_chlo); + + if (fate == kFateProcess) { + QUICHE_DCHECK( + parsed_chlo->legacy_version_encapsulation_inner_packet.empty() || + !packet_info->version.UsesTls()); + if (MaybeHandleLegacyVersionEncapsulation( + this, parsed_chlo->legacy_version_encapsulation_inner_packet, + *packet_info)) { + return; } - ProcessChlo({alpn_extractor.ConsumeAlpn()}, alpn_extractor.ConsumeSni(), - packet_info); - } break; + ProcessChlo(*std::move(parsed_chlo), packet_info); + return; + } + } + + switch (fate) { + case kFateProcess: + // kFateProcess have been processed above. + QUIC_BUG(quic_dispatcher_bad_packet_fate) << fate; + break; case kFateTimeWait: // Add this connection_id to the time-wait state, to safely reject // future packets. @@ -779,6 +765,89 @@ void QuicDispatcher::ProcessHeader(ReceivedPacketInfo* packet_info) { } } +absl::optional +QuicDispatcher::TryExtractChloOrBufferEarlyPacket( + const ReceivedPacketInfo& packet_info) { + if (packet_info.version.UsesTls()) { + bool has_full_tls_chlo = false; + std::string sni; + std::vector alpns; + bool resumption_attempted = false, early_data_attempted = false; + if (buffered_packets_.HasBufferedPackets( + packet_info.destination_connection_id)) { + // If we already have buffered packets for this connection ID, + // use the associated TlsChloExtractor to parse this packet. + has_full_tls_chlo = buffered_packets_.IngestPacketForTlsChloExtraction( + packet_info.destination_connection_id, packet_info.version, + packet_info.packet, &alpns, &sni, &resumption_attempted, + &early_data_attempted); + } else { + // If we do not have a BufferedPacketList for this connection ID, + // create a single-use one to check whether this packet contains a + // full single-packet CHLO. + TlsChloExtractor tls_chlo_extractor; + tls_chlo_extractor.IngestPacket(packet_info.version, packet_info.packet); + if (tls_chlo_extractor.HasParsedFullChlo()) { + // This packet contains a full single-packet CHLO. + has_full_tls_chlo = true; + alpns = tls_chlo_extractor.alpns(); + sni = tls_chlo_extractor.server_name(); + resumption_attempted = tls_chlo_extractor.resumption_attempted(); + early_data_attempted = tls_chlo_extractor.early_data_attempted(); + } + } + if (!has_full_tls_chlo) { + // This packet does not contain a full CHLO. It could be a 0-RTT + // packet that arrived before the CHLO (due to loss or reordering), + // or it could be a fragment of a multi-packet CHLO. + BufferEarlyPacket(packet_info); + return absl::nullopt; + } + + ParsedClientHello parsed_chlo; + parsed_chlo.sni = std::move(sni); + parsed_chlo.alpns = std::move(alpns); + if (packet_info.retry_token.has_value()) { + parsed_chlo.retry_token = std::string(*packet_info.retry_token); + } + parsed_chlo.resumption_attempted = resumption_attempted; + parsed_chlo.early_data_attempted = early_data_attempted; + return parsed_chlo; + } + + ChloAlpnSniExtractor alpn_extractor; + if (GetQuicFlag(FLAGS_quic_allow_chlo_buffering) && + !ChloExtractor::Extract(packet_info.packet, packet_info.version, + config_->create_session_tag_indicators(), + &alpn_extractor, + packet_info.destination_connection_id.length())) { + // Buffer non-CHLO packets. + BufferEarlyPacket(packet_info); + return absl::nullopt; + } + + // We only apply this check for versions that do not use the IETF + // invariant header because those versions are already checked in + // QuicDispatcher::MaybeDispatchPacket. + if (packet_info.version_flag && + !packet_info.version.HasIetfInvariantHeader() && + crypto_config()->validate_chlo_size() && + packet_info.packet.length() < kMinClientInitialPacketLength) { + QUIC_DVLOG(1) << "Dropping CHLO packet which is too short, length: " + << packet_info.packet.length(); + QUIC_CODE_COUNT(quic_drop_small_chlo_packets); + return absl::nullopt; + } + + ParsedClientHello parsed_chlo; + parsed_chlo.legacy_version_encapsulation_inner_packet = + alpn_extractor.ConsumeLegacyVersionEncapsulationInnerPacket(); + parsed_chlo.sni = alpn_extractor.ConsumeSni(); + parsed_chlo.uaid = alpn_extractor.ConsumeUaid(); + parsed_chlo.alpns = {alpn_extractor.ConsumeAlpn()}; + return parsed_chlo; +} + std::string QuicDispatcher::SelectAlpn(const std::vector& alpns) { if (alpns.empty()) { return ""; @@ -812,9 +881,9 @@ QuicDispatcher::QuicPacketFate QuicDispatcher::ValidityChecks( void QuicDispatcher::CleanUpSession(QuicConnectionId server_connection_id, QuicConnection* connection, - QuicErrorCode error, - const std::string& error_details, - ConnectionCloseSource source) { + QuicErrorCode /*error*/, + const std::string& /*error_details*/, + ConnectionCloseSource /*source*/) { write_blocked_list_.erase(connection); QuicTimeWaitListManager::TimeWaitAction action = QuicTimeWaitListManager::SEND_STATELESS_RESET; @@ -823,61 +892,29 @@ void QuicDispatcher::CleanUpSession(QuicConnectionId server_connection_id, action = QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS; } else { if (!connection->IsHandshakeComplete()) { - const bool fix_dispatcher_sent_error_code = - GetQuicReloadableFlag(quic_fix_dispatcher_sent_error_code) && - source == ConnectionCloseSource::FROM_SELF; // TODO(fayang): Do not serialize connection close packet if the // connection is closed by the client. - if (fix_dispatcher_sent_error_code) { - QUIC_RELOADABLE_FLAG_COUNT(quic_fix_dispatcher_sent_error_code); - } if (!connection->version().HasIetfInvariantHeader()) { QUIC_CODE_COUNT(gquic_add_to_time_wait_list_with_handshake_failed); } else { QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_handshake_failed); } - if (support_multiple_cid_per_connection_) { - QUIC_RESTART_FLAG_COUNT_N( - quic_dispatcher_support_multiple_cid_per_connection_v2, 1, 2); - // This serializes a connection close termination packet and adds the - // connection to the time wait list. - StatelessConnectionTerminator terminator( - server_connection_id, connection->version(), helper_.get(), - time_wait_list_manager_.get()); - terminator.CloseConnection( - fix_dispatcher_sent_error_code ? error : QUIC_HANDSHAKE_FAILED, - fix_dispatcher_sent_error_code - ? error_details - : "Connection is closed by server before handshake confirmed", - connection->version().HasIetfInvariantHeader(), - connection->GetActiveServerConnectionIds()); - } else { - action = QuicTimeWaitListManager::SEND_TERMINATION_PACKETS; - // This serializes a connection close termination packet and adds the - // connection to the time wait list. - StatelesslyTerminateConnection( - connection->connection_id(), - connection->version().HasIetfInvariantHeader() - ? IETF_QUIC_LONG_HEADER_PACKET - : GOOGLE_QUIC_PACKET, - /*version_flag=*/true, - connection->version().HasLengthPrefixedConnectionIds(), - connection->version(), - fix_dispatcher_sent_error_code ? error : QUIC_HANDSHAKE_FAILED, - fix_dispatcher_sent_error_code - ? error_details - : "Connection is closed by server before handshake confirmed", - // Although it is our intention to send termination packets, the - // |action| argument is not used by this call to - // StatelesslyTerminateConnection(). - action); - } + // This serializes a connection close termination packet and adds the + // connection to the time wait list. + StatelessConnectionTerminator terminator( + server_connection_id, connection->version(), helper_.get(), + time_wait_list_manager_.get()); + terminator.CloseConnection( + QUIC_HANDSHAKE_FAILED, + "Connection is closed by server before handshake confirmed", + connection->version().HasIetfInvariantHeader(), + connection->GetActiveServerConnectionIds()); return; } QUIC_CODE_COUNT(quic_v44_add_to_time_wait_list_with_stateless_reset); } time_wait_list_manager_->AddConnectionIdToTimeWait( - server_connection_id, action, + action, TimeWaitConnectionInfo( connection->version().HasIetfInvariantHeader(), connection->termination_packets(), @@ -898,18 +935,12 @@ void QuicDispatcher::StopAcceptingNewConnections() { void QuicDispatcher::PerformActionOnActiveSessions( std::function operation) const { - if (use_reference_counted_session_map_) { - absl::flat_hash_set visited_session; - visited_session.reserve(reference_counted_session_map_.size()); - for (auto const& kv : reference_counted_session_map_) { - QuicSession* session = kv.second.get(); - if (visited_session.insert(session).second) { - operation(session); - } - } - } else { - for (auto const& kv : session_map_) { - operation(kv.second.get()); + absl::flat_hash_set visited_session; + visited_session.reserve(reference_counted_session_map_.size()); + for (auto const& kv : reference_counted_session_map_) { + QuicSession* session = kv.second.get(); + if (visited_session.insert(session).second) { + operation(session); } } } @@ -917,7 +948,6 @@ void QuicDispatcher::PerformActionOnActiveSessions( // Get a snapshot of all sessions. std::vector> QuicDispatcher::GetSessionsSnapshot() const { - QUICHE_DCHECK(use_reference_counted_session_map_); std::vector> snapshot; snapshot.reserve(reference_counted_session_map_.size()); absl::flat_hash_set visited_session; @@ -937,29 +967,20 @@ std::unique_ptr QuicDispatcher::GetPerPacketContext() } void QuicDispatcher::DeleteSessions() { - if (use_reference_counted_session_map_) { - if (!write_blocked_list_.empty()) { - for (const auto& session : closed_ref_counted_session_list_) { - if (write_blocked_list_.erase(session->connection()) != 0) { - QUIC_BUG(quic_bug_12724_2) - << "QuicConnection was in WriteBlockedList before destruction " - << session->connection()->connection_id(); - } + if (!write_blocked_list_.empty()) { + for (const auto& session : closed_ref_counted_session_list_) { + if (write_blocked_list_.erase(session->connection()) != 0) { + QUIC_BUG(quic_bug_12724_2) + << "QuicConnection was in WriteBlockedList before destruction " + << session->connection()->connection_id(); } } - closed_ref_counted_session_list_.clear(); - } else { - if (!write_blocked_list_.empty()) { - for (const std::unique_ptr& session : closed_session_list_) { - if (write_blocked_list_.erase(session->connection()) != 0) { - QUIC_BUG(quic_bug_12724_3) - << "QuicConnection was in WriteBlockedList before destruction " - << session->connection()->connection_id(); - } - } - } - closed_session_list_.clear(); } + closed_ref_counted_session_list_.clear(); +} + +void QuicDispatcher::ClearStatelessResetAddresses() { + recent_stateless_reset_addresses_.clear(); } void QuicDispatcher::OnCanWrite() { @@ -995,28 +1016,15 @@ bool QuicDispatcher::HasPendingWrites() const { } void QuicDispatcher::Shutdown() { - if (use_reference_counted_session_map_) { - while (!reference_counted_session_map_.empty()) { - QuicSession* session = - reference_counted_session_map_.begin()->second.get(); - session->connection()->CloseConnection( - QUIC_PEER_GOING_AWAY, "Server shutdown imminent", - ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); - // Validate that the session removes itself from the session map on close. - QUICHE_DCHECK(reference_counted_session_map_.empty() || - reference_counted_session_map_.begin()->second.get() != - session); - } - } else { - while (!session_map_.empty()) { - QuicSession* session = session_map_.begin()->second.get(); - session->connection()->CloseConnection( - QUIC_PEER_GOING_AWAY, "Server shutdown imminent", - ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); - // Validate that the session removes itself from the session map on close. - QUICHE_DCHECK(session_map_.empty() || - session_map_.begin()->second.get() != session); - } + while (!reference_counted_session_map_.empty()) { + QuicSession* session = reference_counted_session_map_.begin()->second.get(); + session->connection()->CloseConnection( + QUIC_PEER_GOING_AWAY, "Server shutdown imminent", + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + // Validate that the session removes itself from the session map on close. + QUICHE_DCHECK(reference_counted_session_map_.empty() || + reference_counted_session_map_.begin()->second.get() != + session); } DeleteSessions(); } @@ -1025,73 +1033,36 @@ void QuicDispatcher::OnConnectionClosed(QuicConnectionId server_connection_id, QuicErrorCode error, const std::string& error_details, ConnectionCloseSource source) { - if (use_reference_counted_session_map_) { - auto it = reference_counted_session_map_.find(server_connection_id); - if (it == reference_counted_session_map_.end()) { - QUIC_BUG(quic_bug_10287_3) - << "ConnectionId " << server_connection_id - << " does not exist in the session map. Error: " - << QuicErrorCodeToString(error); - QUIC_BUG(quic_bug_10287_4) << QuicStackTrace(); - return; - } + auto it = reference_counted_session_map_.find(server_connection_id); + if (it == reference_counted_session_map_.end()) { + QUIC_BUG(quic_bug_10287_3) << "ConnectionId " << server_connection_id + << " does not exist in the session map. Error: " + << QuicErrorCodeToString(error); + QUIC_BUG(quic_bug_10287_4) << QuicStackTrace(); + return; + } - QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR) - << "Closing connection (" << server_connection_id - << ") due to error: " << QuicErrorCodeToString(error) - << ", with details: " << error_details; - - QuicConnection* connection = it->second->connection(); - if (ShouldDestroySessionAsynchronously()) { - // Set up alarm to fire immediately to bring destruction of this session - // out of current call stack. - if (closed_ref_counted_session_list_.empty()) { - delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(), - QuicTime::Delta::Zero()); - } - closed_ref_counted_session_list_.push_back(std::move(it->second)); - } - CleanUpSession(it->first, connection, error, error_details, source); - if (support_multiple_cid_per_connection_) { - QUIC_RESTART_FLAG_COUNT_N( - quic_dispatcher_support_multiple_cid_per_connection_v2, 1, 2); - for (const QuicConnectionId& cid : - connection->GetActiveServerConnectionIds()) { - reference_counted_session_map_.erase(cid); - } - --num_sessions_in_session_map_; - } else { - reference_counted_session_map_.erase(it); - } - } else { - auto it = session_map_.find(server_connection_id); - if (it == session_map_.end()) { - QUIC_BUG(quic_bug_10287_5) - << "ConnectionId " << server_connection_id - << " does not exist in the session map. Error: " - << QuicErrorCodeToString(error); - QUIC_BUG(quic_bug_10287_6) << QuicStackTrace(); - return; - } + QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR) + << "Closing connection (" << server_connection_id + << ") due to error: " << QuicErrorCodeToString(error) + << ", with details: " << error_details; - QUIC_DLOG_IF(INFO, error != QUIC_NO_ERROR) - << "Closing connection (" << server_connection_id - << ") due to error: " << QuicErrorCodeToString(error) - << ", with details: " << error_details; - - QuicConnection* connection = it->second->connection(); - if (ShouldDestroySessionAsynchronously()) { - // Set up alarm to fire immediately to bring destruction of this session - // out of current call stack. - if (closed_session_list_.empty()) { - delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(), - QuicTime::Delta::Zero()); - } - closed_session_list_.push_back(std::move(it->second)); + QuicConnection* connection = it->second->connection(); + if (ShouldDestroySessionAsynchronously()) { + // Set up alarm to fire immediately to bring destruction of this session + // out of current call stack. + if (closed_ref_counted_session_list_.empty()) { + delete_sessions_alarm_->Update(helper()->GetClock()->ApproximateNow(), + QuicTime::Delta::Zero()); } - CleanUpSession(it->first, connection, error, error_details, source); - session_map_.erase(it); + closed_ref_counted_session_list_.push_back(std::move(it->second)); } + CleanUpSession(it->first, connection, error, error_details, source); + for (const QuicConnectionId& cid : + connection->GetActiveServerConnectionIds()) { + reference_counted_session_map_.erase(cid); + } + --num_sessions_in_session_map_; } void QuicDispatcher::OnWriteBlocked( @@ -1117,7 +1088,6 @@ void QuicDispatcher::OnStopSendingReceived( void QuicDispatcher::OnNewConnectionIdSent( const QuicConnectionId& server_connection_id, const QuicConnectionId& new_connection_id) { - QUICHE_DCHECK(support_multiple_cid_per_connection_); auto it = reference_counted_session_map_.find(server_connection_id); if (it == reference_counted_session_map_.end()) { QUIC_BUG(quic_bug_10287_7) @@ -1126,6 +1096,8 @@ void QuicDispatcher::OnNewConnectionIdSent( << server_connection_id << " new_connection_id: " << new_connection_id; return; } + // Count new connection ID added to the dispatcher map. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 6, 6); auto insertion_result = reference_counted_session_map_.insert( std::make_pair(new_connection_id, it->second)); QUICHE_DCHECK(insertion_result.second); @@ -1133,7 +1105,6 @@ void QuicDispatcher::OnNewConnectionIdSent( void QuicDispatcher::OnConnectionIdRetired( const QuicConnectionId& server_connection_id) { - QUICHE_DCHECK(support_multiple_cid_per_connection_); reference_counted_session_map_.erase(server_connection_id); } @@ -1158,9 +1129,8 @@ void QuicDispatcher::StatelesslyTerminateConnection( << ", error_code:" << error_code << ", error_details:" << error_details; time_wait_list_manager_->AddConnectionIdToTimeWait( - server_connection_id, action, - TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr, - {server_connection_id})); + action, TimeWaitConnectionInfo(format != GOOGLE_QUIC_PACKET, nullptr, + {server_connection_id})); return; } @@ -1198,7 +1168,7 @@ void QuicDispatcher::StatelesslyTerminateConnection( /*ietf_quic=*/format != GOOGLE_QUIC_PACKET, use_length_prefix, /*versions=*/{})); time_wait_list_manager()->AddConnectionIdToTimeWait( - server_connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, TimeWaitConnectionInfo(/*ietf_quic=*/format != GOOGLE_QUIC_PACKET, &termination_packets, {server_connection_id})); } @@ -1236,42 +1206,38 @@ void QuicDispatcher::ProcessBufferedChlos(size_t max_connections_to_create) { if (packets.empty()) { return; } + if (!packet_list.parsed_chlo.has_value()) { + QUIC_BUG(quic_dispatcher_no_parsed_chlo_in_buffered_packets) + << "Buffered connection has no CHLO. connection_id:" + << server_connection_id; + continue; + } + const ParsedClientHello& parsed_chlo = *packet_list.parsed_chlo; QuicConnectionId original_connection_id = server_connection_id; server_connection_id = MaybeReplaceServerConnectionId(server_connection_id, packet_list.version); - std::string alpn = SelectAlpn(packet_list.alpns); - std::unique_ptr session = - CreateQuicSession(server_connection_id, packets.front().self_address, - packets.front().peer_address, alpn, - packet_list.version, packet_list.sni); + std::string alpn = SelectAlpn(parsed_chlo.alpns); + std::unique_ptr session = CreateQuicSession( + server_connection_id, packets.front().self_address, + packets.front().peer_address, alpn, packet_list.version, parsed_chlo); if (original_connection_id != server_connection_id) { session->connection()->SetOriginalDestinationConnectionId( original_connection_id); } QUIC_DLOG(INFO) << "Created new session for " << server_connection_id; - if (use_reference_counted_session_map_) { - auto insertion_result = reference_counted_session_map_.insert( - std::make_pair(server_connection_id, - std::shared_ptr(std::move(session)))); - if (!insertion_result.second) { - QUIC_BUG(quic_bug_12724_5) - << "Tried to add a session to session_map with existing connection " - "id: " - << server_connection_id; - } else if (support_multiple_cid_per_connection_) { - ++num_sessions_in_session_map_; - } - DeliverPacketsToSession(packets, insertion_result.first->second.get()); - } else { - auto insertion_result = session_map_.insert( - std::make_pair(server_connection_id, std::move(session))); - QUIC_BUG_IF(quic_bug_12724_6, !insertion_result.second) + auto insertion_result = reference_counted_session_map_.insert( + std::make_pair(server_connection_id, + std::shared_ptr(std::move(session)))); + if (!insertion_result.second) { + QUIC_BUG(quic_bug_12724_5) << "Tried to add a session to session_map with existing connection " "id: " << server_connection_id; - DeliverPacketsToSession(packets, insertion_result.first->second.get()); + } else { + ++num_sessions_in_session_map_; } + DeliverPacketsToSession(packets, insertion_result.first->second.get()); } } @@ -1314,15 +1280,14 @@ void QuicDispatcher::BufferEarlyPacket(const ReceivedPacketInfo& packet_info) { EnqueuePacketResult rs = buffered_packets_.EnqueuePacket( packet_info.destination_connection_id, packet_info.form != GOOGLE_QUIC_PACKET, packet_info.packet, - packet_info.self_address, packet_info.peer_address, /*is_chlo=*/false, - /*alpns=*/{}, /*sni=*/absl::string_view(), packet_info.version); + packet_info.self_address, packet_info.peer_address, packet_info.version, + /*parsed_chlo=*/absl::nullopt); if (rs != EnqueuePacketResult::SUCCESS) { OnBufferPacketFailure(rs, packet_info.destination_connection_id); } } -void QuicDispatcher::ProcessChlo(const std::vector& alpns, - absl::string_view sni, +void QuicDispatcher::ProcessChlo(ParsedClientHello parsed_chlo, ReceivedPacketInfo* packet_info) { if (!buffered_packets_.HasBufferedPackets( packet_info->destination_connection_id) && @@ -1338,7 +1303,7 @@ void QuicDispatcher::ProcessChlo(const std::vector& alpns, packet_info->destination_connection_id, packet_info->form != GOOGLE_QUIC_PACKET, packet_info->packet, packet_info->self_address, packet_info->peer_address, - /*is_chlo=*/true, alpns, sni, packet_info->version); + packet_info->version, std::move(parsed_chlo)); if (rs != EnqueuePacketResult::SUCCESS) { OnBufferPacketFailure(rs, packet_info->destination_connection_id); } @@ -1350,10 +1315,10 @@ void QuicDispatcher::ProcessChlo(const std::vector& alpns, packet_info->destination_connection_id = MaybeReplaceServerConnectionId( original_connection_id, packet_info->version); // Creates a new session and process all buffered packets for this connection. - std::string alpn = SelectAlpn(alpns); + std::string alpn = SelectAlpn(parsed_chlo.alpns); std::unique_ptr session = CreateQuicSession( packet_info->destination_connection_id, packet_info->self_address, - packet_info->peer_address, alpn, packet_info->version, sni); + packet_info->peer_address, alpn, packet_info->version, parsed_chlo); if (QUIC_PREDICT_FALSE(session == nullptr)) { QUIC_BUG(quic_bug_10287_8) << "CreateQuicSession returned nullptr for " @@ -1370,28 +1335,18 @@ void QuicDispatcher::ProcessChlo(const std::vector& alpns, << packet_info->destination_connection_id; QuicSession* session_ptr; - if (use_reference_counted_session_map_) { - auto insertion_result = - reference_counted_session_map_.insert(std::make_pair( - packet_info->destination_connection_id, - std::shared_ptr(std::move(session.release())))); - if (!insertion_result.second) { - QUIC_BUG(quic_bug_10287_9) - << "Tried to add a session to session_map with existing " - "connection id: " - << packet_info->destination_connection_id; - } else if (support_multiple_cid_per_connection_) { - ++num_sessions_in_session_map_; - } - session_ptr = insertion_result.first->second.get(); - } else { - auto insertion_result = session_map_.insert(std::make_pair( - packet_info->destination_connection_id, std::move(session))); - QUIC_BUG_IF(quic_bug_12724_8, !insertion_result.second) - << "Tried to add a session to session_map with existing connection id: " + auto insertion_result = reference_counted_session_map_.insert(std::make_pair( + packet_info->destination_connection_id, + std::shared_ptr(std::move(session.release())))); + if (!insertion_result.second) { + QUIC_BUG(quic_bug_10287_9) + << "Tried to add a session to session_map with existing " + "connection id: " << packet_info->destination_connection_id; - session_ptr = insertion_result.first->second.get(); + } else { + ++num_sessions_in_session_map_; } + session_ptr = insertion_result.first->second.get(); std::list packets = buffered_packets_.DeliverPackets(packet_info->destination_connection_id) .buffered_packets; @@ -1449,8 +1404,14 @@ bool QuicDispatcher::IsSupportedVersion(const ParsedQuicVersion version) { void QuicDispatcher::MaybeResetPacketsWithNoVersion( const ReceivedPacketInfo& packet_info) { QUICHE_DCHECK(!packet_info.version_flag); - if (GetQuicReloadableFlag(quic_fix_stateless_reset) && - packet_info.form != GOOGLE_QUIC_PACKET) { + // Do not send a stateless reset if a reset has been sent to this address + // recently. + if (recent_stateless_reset_addresses_.contains(packet_info.peer_address)) { + QUIC_CODE_COUNT(quic_donot_send_reset_repeatedly); + QUICHE_DCHECK(use_recent_reset_addresses_); + return; + } + if (packet_info.form != GOOGLE_QUIC_PACKET) { // Drop IETF packets smaller than the minimal stateless reset length. if (packet_info.packet.length() <= QuicFramer::GetMinStatelessResetPacketLength()) { @@ -1466,8 +1427,24 @@ void QuicDispatcher::MaybeResetPacketsWithNoVersion( QUIC_CODE_COUNT(drop_too_small_packets); return; } - // TODO(fayang): Consider rate limiting reset packets if reset packet size > - // packet_length. + } + if (use_recent_reset_addresses_) { + QUIC_RESTART_FLAG_COUNT(quic_use_recent_reset_addresses); + // Do not send a stateless reset if there are too many stateless reset + // addresses. + if (recent_stateless_reset_addresses_.size() >= + GetQuicFlag(FLAGS_quic_max_recent_stateless_reset_addresses)) { + QUIC_CODE_COUNT(quic_too_many_recent_reset_addresses); + return; + } + if (recent_stateless_reset_addresses_.empty()) { + clear_stateless_reset_addresses_alarm_->Update( + helper()->GetClock()->ApproximateNow() + + QuicTime::Delta::FromMilliseconds(GetQuicFlag( + FLAGS_quic_recent_stateless_reset_addresses_lifetime_ms)), + QuicTime::Delta::Zero()); + } + recent_stateless_reset_addresses_.emplace(packet_info.peer_address); } time_wait_list_manager()->SendPublicReset( @@ -1478,12 +1455,7 @@ void QuicDispatcher::MaybeResetPacketsWithNoVersion( } size_t QuicDispatcher::NumSessions() const { - if (support_multiple_cid_per_connection_) { - return num_sessions_in_session_map_; - } - return use_reference_counted_session_map_ - ? reference_counted_session_map_.size() - : session_map_.size(); + return num_sessions_in_session_map_; } } // namespace quic diff --git a/gquiche/quic/core/quic_dispatcher.h b/gquiche/quic/core/quic_dispatcher.h index 2c072ec4..dde30126 100644 --- a/gquiche/quic/core/quic_dispatcher.h +++ b/gquiche/quic/core/quic_dispatcher.h @@ -27,9 +27,9 @@ #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_time_wait_list_manager.h" #include "gquiche/quic/core/quic_version_manager.h" -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_reference_counted.h" #include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { namespace test { @@ -45,7 +45,8 @@ class QUIC_NO_EXPORT QuicDispatcher public QuicBufferedPacketStore::VisitorInterface { public: // Ideally we'd have a linked_hash_set: the boolean is unused. - using WriteBlockedList = QuicLinkedHashMap; + using WriteBlockedList = + quiche::QuicheLinkedHashMap; QuicDispatcher( const QuicConfig* config, @@ -120,10 +121,6 @@ class QUIC_NO_EXPORT QuicDispatcher void OnConnectionAddedToTimeWaitList( QuicConnectionId server_connection_id) override; - using SessionMap = absl::flat_hash_map, - QuicConnectionIdHash>; - using ReferenceCountedSessionMap = absl::flat_hash_map, @@ -131,11 +128,12 @@ class QUIC_NO_EXPORT QuicDispatcher size_t NumSessions() const; - const SessionMap& session_map() const { return session_map_; } - // Deletes all sessions on the closed session list and clears the list. virtual void DeleteSessions(); + // Clear recent_stateless_reset_addresses_. + void ClearStatelessResetAddresses(); + using ConnectionIdMap = absl:: flat_hash_map; @@ -166,22 +164,15 @@ class QUIC_NO_EXPORT QuicDispatcher bool accept_new_connections() const { return accept_new_connections_; } - bool use_reference_counted_session_map() const { - return use_reference_counted_session_map_; - } - - bool support_multiple_cid_per_connection() const { - return support_multiple_cid_per_connection_; - } - protected: + // Creates a QUIC session based on the given information. + // |alpn| is the selected ALPN from |parsed_chlo.alpns|. virtual std::unique_ptr CreateQuicSession( QuicConnectionId server_connection_id, const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, + const QuicSocketAddress& peer_address, absl::string_view alpn, const ParsedQuicVersion& version, - absl::string_view sni) = 0; + const ParsedClientHello& parsed_chlo) = 0; // Tries to validate and dispatch packet based on available information. // Returns true if packet is dropped or successfully dispatched (e.g., @@ -233,6 +224,15 @@ class QUIC_NO_EXPORT QuicDispatcher // TODO(fayang): Merge ValidityChecks into MaybeDispatchPacket. virtual QuicPacketFate ValidityChecks(const ReceivedPacketInfo& packet_info); + // Extra validity checks after the full Client Hello is parsed, this allows + // subclasses to reject a connection based on sni or alpn. + // Only called if ValidityChecks returns kFateProcess. + virtual QuicPacketFate ValidityChecksOnFullChlo( + const ReceivedPacketInfo& /*packet_info*/, + const ParsedClientHello& /*parsed_chlo*/) const { + return kFateProcess; + } + // Create and return the time wait list manager for this dispatcher, which // will be owned by the dispatcher as time_wait_list_manager_ virtual QuicTimeWaitListManager* CreateQuicTimeWaitListManager(); @@ -240,10 +240,10 @@ class QUIC_NO_EXPORT QuicDispatcher // Buffers packet until it can be delivered to a connection. void BufferEarlyPacket(const ReceivedPacketInfo& packet_info); - // Called when |packet_info| is a CHLO packet. Creates a new connection and - // delivers any buffered packets for that connection id. - void ProcessChlo(const std::vector& alpns, - absl::string_view sni, + // Called when |packet_info| is the last received packet of the client hello. + // |parsed_chlo| is the parsed version of the client hello. Creates a new + // connection and delivers any buffered packets for that connection id. + void ProcessChlo(ParsedClientHello parsed_chlo, ReceivedPacketInfo* packet_info); // Return true if dispatcher wants to destroy session outside of @@ -272,6 +272,10 @@ class QUIC_NO_EXPORT QuicDispatcher return session_helper_.get(); } + const QuicCryptoServerStreamBase::Helper* session_helper() const { + return session_helper_.get(); + } + QuicAlarmFactory* alarm_factory() { return alarm_factory_.get(); } QuicPacketWriter* writer() { return writer_.get(); } @@ -375,6 +379,17 @@ class QUIC_NO_EXPORT QuicDispatcher // ProcessValidatedPacketWithUnknownConnectionId. void ProcessHeader(ReceivedPacketInfo* packet_info); + // Try to extract information(sni, alpns, ...) if the full Client Hello has + // been parsed. + // + // Return the parsed client hello if the full Client Hello has been + // successfully parsed. + // + // Otherwise return absl::nullopt and either buffer or (rarely) drop the + // packet. + absl::optional TryExtractChloOrBufferEarlyPacket( + const ReceivedPacketInfo& packet_info); + // Deliver |packets| to |session| for further processing. void DeliverPacketsToSession( const std::list& packets, @@ -393,7 +408,6 @@ class QUIC_NO_EXPORT QuicDispatcher // The list of connections waiting to write. WriteBlockedList write_blocked_list_; - SessionMap session_map_; ReferenceCountedSessionMap reference_counted_session_map_; // Entity that manages connection_ids in time wait state. @@ -453,17 +467,19 @@ class QUIC_NO_EXPORT QuicDispatcher // version does not allow variable length connection ID. uint8_t expected_server_connection_id_length_; + // Records client addresses that have been recently reset. + absl::flat_hash_set + recent_stateless_reset_addresses_; + + // An alarm which clear recent_stateless_reset_addresses_. + std::unique_ptr clear_stateless_reset_addresses_alarm_; + // If true, change expected_server_connection_id_length_ to be the received // destination connection ID length of all IETF long headers. bool should_update_expected_server_connection_id_length_; - const bool use_reference_counted_session_map_ = - GetQuicRestartFlag(quic_use_reference_counted_sesssion_map); - const bool support_multiple_cid_per_connection_ = - use_reference_counted_session_map_ && - GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid_v2) && - GetQuicRestartFlag( - quic_dispatcher_support_multiple_cid_per_connection_v2); + const bool use_recent_reset_addresses_ = + GetQuicRestartFlag(quic_use_recent_reset_addresses); }; } // namespace quic diff --git a/gquiche/quic/core/quic_dispatcher_test.cc b/gquiche/quic/core/quic_dispatcher_test.cc index 41a9d1f0..f6cd8ffd 100644 --- a/gquiche/quic/core/quic_dispatcher_test.cc +++ b/gquiche/quic/core/quic_dispatcher_test.cc @@ -68,12 +68,8 @@ class TestQuicSpdyServerSession : public QuicServerSessionBase { QuicConnection* connection, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache) - : QuicServerSessionBase(config, - CurrentSupportedVersions(), - connection, - nullptr, - nullptr, - crypto_config, + : QuicServerSessionBase(config, CurrentSupportedVersions(), connection, + nullptr, nullptr, crypto_config, compressed_certs_cache) { Initialize(); } @@ -83,26 +79,17 @@ class TestQuicSpdyServerSession : public QuicServerSessionBase { ~TestQuicSpdyServerSession() override { DeleteConnection(); } - MOCK_METHOD(void, - OnConnectionClosed, + MOCK_METHOD(void, OnConnectionClosed, (const QuicConnectionCloseFrame& frame, ConnectionCloseSource source), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (QuicStreamId id), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (PendingStream*), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingBidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingUnidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), (override)); std::unique_ptr CreateQuicCryptoServerStream( @@ -121,32 +108,25 @@ class TestDispatcher : public QuicDispatcher { public: TestDispatcher(const QuicConfig* config, const QuicCryptoServerConfig* crypto_config, - QuicVersionManager* version_manager, - QuicRandom* random) - : QuicDispatcher(config, - crypto_config, - version_manager, + QuicVersionManager* version_manager, QuicRandom* random) + : QuicDispatcher(config, crypto_config, version_manager, std::make_unique(), std::unique_ptr( new QuicSimpleCryptoServerStreamHelper()), - std::make_unique(), + std::make_unique(), kQuicDefaultConnectionIdLength), random_(random) {} - MOCK_METHOD(std::unique_ptr, - CreateQuicSession, + MOCK_METHOD(std::unique_ptr, CreateQuicSession, (QuicConnectionId connection_id, const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, - const quic::ParsedQuicVersion& version, - absl::string_view sni), + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const ParsedClientHello& parsed_chlo), (override)); - MOCK_METHOD(bool, - ShouldCreateOrBufferPacketForConnection, - (const ReceivedPacketInfo& packet_info), - (override)); + MOCK_METHOD(bool, ShouldCreateOrBufferPacketForConnection, + (const ReceivedPacketInfo& packet_info), (override)); struct TestQuicPerPacketContext : public QuicPerPacketContext { std::string custom_packet_context; @@ -167,6 +147,7 @@ class TestDispatcher : public QuicDispatcher { std::string custom_packet_context_; + using QuicDispatcher::MaybeDispatchPacket; using QuicDispatcher::SetAllowShortInitialServerConnectionIds; using QuicDispatcher::writer; @@ -183,9 +164,7 @@ class MockServerConnection : public MockQuicConnection { MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, QuicDispatcher* dispatcher) - : MockQuicConnection(connection_id, - helper, - alarm_factory, + : MockQuicConnection(connection_id, helper, alarm_factory, Perspective::IS_SERVER), dispatcher_(dispatcher), active_connection_ids_({connection_id}) {} @@ -229,15 +208,12 @@ class QuicDispatcherTestBase : public QuicTestWithParam { : version_(GetParam()), version_manager_(AllSupportedVersions()), crypto_config_(QuicCryptoServerConfig::TESTING, - QuicRandom::GetInstance(), - std::move(proof_source), + QuicRandom::GetInstance(), std::move(proof_source), KeyExchangeSource::Default()), server_address_(QuicIpAddress::Any4(), 5), - dispatcher_( - new NiceMock(&config_, - &crypto_config_, - &version_manager_, - mock_helper_.GetRandomGenerator())), + dispatcher_(new NiceMock( + &config_, &crypto_config_, &version_manager_, + mock_helper_.GetRandomGenerator())), time_wait_list_manager_(nullptr), session1_(nullptr), session2_(nullptr), @@ -272,8 +248,7 @@ class QuicDispatcherTestBase : public QuicTestWithParam { // using the version under test. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - const std::string& data) { + bool has_version_flag, const std::string& data) { ProcessPacket(peer_address, server_connection_id, has_version_flag, data, CONNECTION_ID_PRESENT, PACKET_4BYTE_PACKET_NUMBER); } @@ -282,8 +257,7 @@ class QuicDispatcherTestBase : public QuicTestWithParam { // using the version under test. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - const std::string& data, + bool has_version_flag, const std::string& data, QuicConnectionIdIncluded server_connection_id_included, QuicPacketNumberLength packet_number_length) { ProcessPacket(peer_address, server_connection_id, has_version_flag, data, @@ -293,8 +267,7 @@ class QuicDispatcherTestBase : public QuicTestWithParam { // Process a packet using the version under test. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - const std::string& data, + bool has_version_flag, const std::string& data, QuicConnectionIdIncluded server_connection_id_included, QuicPacketNumberLength packet_number_length, uint64_t packet_number) { @@ -306,10 +279,8 @@ class QuicDispatcherTestBase : public QuicTestWithParam { // Processes a packet. void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, - bool has_version_flag, - ParsedQuicVersion version, - const std::string& data, - bool full_padding, + bool has_version_flag, ParsedQuicVersion version, + const std::string& data, bool full_padding, QuicConnectionIdIncluded server_connection_id_included, QuicPacketNumberLength packet_number_length, uint64_t packet_number) { @@ -323,10 +294,8 @@ class QuicDispatcherTestBase : public QuicTestWithParam { void ProcessPacket(QuicSocketAddress peer_address, QuicConnectionId server_connection_id, QuicConnectionId client_connection_id, - bool has_version_flag, - ParsedQuicVersion version, - const std::string& data, - bool full_padding, + bool has_version_flag, ParsedQuicVersion version, + const std::string& data, bool full_padding, QuicConnectionIdIncluded server_connection_id_included, QuicConnectionIdIncluded client_connection_id_included, QuicPacketNumberLength packet_number_length, @@ -344,8 +313,7 @@ class QuicDispatcherTestBase : public QuicTestWithParam { void ProcessReceivedPacket( std::unique_ptr received_packet, - const QuicSocketAddress& peer_address, - const ParsedQuicVersion& version, + const QuicSocketAddress& peer_address, const ParsedQuicVersion& version, const QuicConnectionId& server_connection_id) { if (version.UsesQuicCrypto() && ChloExtractor::Extract(*received_packet, version, {}, nullptr, @@ -371,12 +339,9 @@ class QuicDispatcherTestBase : public QuicTestWithParam { } std::unique_ptr CreateSession( - TestDispatcher* dispatcher, - const QuicConfig& config, - QuicConnectionId connection_id, - const QuicSocketAddress& /*peer_address*/, - MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, + TestDispatcher* dispatcher, const QuicConfig& config, + QuicConnectionId connection_id, const QuicSocketAddress& /*peer_address*/, + MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache, TestQuicSpdyServerSession** session_ptr) { @@ -418,8 +383,7 @@ class QuicDispatcherTestBase : public QuicTestWithParam { } void ProcessUndecryptableEarlyPacket( - const ParsedQuicVersion& version, - const QuicSocketAddress& peer_address, + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, const QuicConnectionId& server_connection_id) { std::unique_ptr encrypted_packet = GetUndecryptableEarlyPacket(version, server_connection_id); @@ -445,21 +409,58 @@ class QuicDispatcherTestBase : public QuicTestWithParam { const QuicSocketAddress& peer_address, const QuicConnectionId& server_connection_id, const QuicConnectionId& client_connection_id) { + ProcessFirstFlight(version, peer_address, server_connection_id, + client_connection_id, TestClientCryptoConfig()); + } + + void ProcessFirstFlight( + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr client_crypto_config) { std::vector> packets = - GetFirstFlightOfPackets(version, server_connection_id, - client_connection_id); + GetFirstFlightOfPackets(version, DefaultQuicConfig(), + server_connection_id, client_connection_id, + std::move(client_crypto_config)); for (auto&& packet : packets) { ProcessReceivedPacket(std::move(packet), peer_address, version, server_connection_id); } } + std::unique_ptr TestClientCryptoConfig() { + auto client_crypto_config = std::make_unique( + crypto_test_utils::ProofVerifierForTesting()); + if (address_token_.has_value()) { + client_crypto_config->LookupOrCreate(TestServerId()) + ->set_source_address_token(*address_token_); + } + return client_crypto_config; + } + + // If called, the first flight packets generated in |ProcessFirstFlight| will + // contain the given |address_token|. + void SetAddressToken(std::string address_token) { + address_token_ = std::move(address_token); + } + std::string ExpectedAlpnForVersion(ParsedQuicVersion version) { return AlpnForVersion(version); } std::string ExpectedAlpn() { return ExpectedAlpnForVersion(version_); } + ParsedClientHello ParsedClientHelloForTest() { + ParsedClientHello parsed_chlo; + parsed_chlo.alpns = {ExpectedAlpn()}; + parsed_chlo.sni = TestHostname(); + if (address_token_.has_value() && + !GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + parsed_chlo.retry_token = *address_token_; + } + return parsed_chlo; + } + void MarkSession1Deleted() { session1_ = nullptr; } void VerifyVersionSupported(ParsedQuicVersion version) { @@ -499,6 +500,11 @@ class QuicDispatcherTestBase : public QuicTestWithParam { const QuicConnectionId& server_connection_id, const QuicConnectionId& client_connection_id); + TestAlarmFactory::TestAlarm* GetClearResetAddressesAlarm() { + return reinterpret_cast( + QuicDispatcherPeer::GetClearResetAddressesAlarm(dispatcher_.get())); + } + ParsedQuicVersion version_; MockQuicConnectionHelper mock_helper_; MockAlarmFactory mock_alarm_factory_; @@ -513,6 +519,7 @@ class QuicDispatcherTestBase : public QuicTestWithParam { std::map> data_connection_map_; QuicBufferedPacketStore* store_; uint64_t connection_id_; + absl::optional address_token_; }; class QuicDispatcherTestAllVersions : public QuicDispatcherTestBase {}; @@ -532,11 +539,14 @@ TEST_P(QuicDispatcherTestAllVersions, TlsClientHelloCreatesSession) { if (version_.UsesQuicCrypto()) { return; } + SetAddressToken("hsdifghdsaifnasdpfjdsk"); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); - EXPECT_CALL(*dispatcher_, - CreateQuicSession(TestConnectionId(1), _, client_address, - Eq(ExpectedAlpn()), _, _)) + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, TestConnectionId(1), client_address, &mock_helper_, &mock_alarm_factory_, &crypto_config_, @@ -558,6 +568,8 @@ void QuicDispatcherTestBase::TestTlsMultiPacketClientHello( if (!version_.UsesTls()) { return; } + SetAddressToken("857293462398"); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); QuicConnectionId server_connection_id = TestConnectionId(); QuicConfig client_config = DefaultQuicConfig(); @@ -568,7 +580,9 @@ void QuicDispatcherTestBase::TestTlsMultiPacketClientHello( client_config.custom_transport_parameters_to_send()[kCustomParameterId] = kCustomParameterValue; std::vector> packets = - GetFirstFlightOfPackets(version_, client_config, server_connection_id); + GetFirstFlightOfPackets(version_, client_config, server_connection_id, + EmptyQuicConnectionId(), + TestClientCryptoConfig()); ASSERT_EQ(packets.size(), 2u); if (add_reordering) { std::swap(packets[0], packets[1]); @@ -585,9 +599,10 @@ void QuicDispatcherTestBase::TestTlsMultiPacketClientHello( << "No session should be created before the rest of the CHLO arrives."; // Processing the second packet should create the new session. - EXPECT_CALL(*dispatcher_, - CreateQuicSession(server_connection_id, _, client_address, - Eq(ExpectedAlpn()), _, TestHostname())) + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(server_connection_id, _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, server_connection_id, client_address, &mock_helper_, &mock_alarm_factory_, &crypto_config_, @@ -632,15 +647,14 @@ TEST_P(QuicDispatcherTestAllVersions, LegacyVersionEncapsulation) { QuicVersionLabel version_label; ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); QuicConnectionId destination_connection_id, source_connection_id; - bool retry_token_present; - absl::string_view retry_token; + absl::optional retry_token; std::string detailed_error; const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( QuicEncryptedPacket(packets[0]->data(), packets[0]->length()), kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_present, &has_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); ASSERT_THAT(error, IsQuicNoError()) << detailed_error; EXPECT_EQ(format, GOOGLE_QUIC_PACKET); EXPECT_TRUE(version_present); @@ -648,7 +662,7 @@ TEST_P(QuicDispatcherTestAllVersions, LegacyVersionEncapsulation) { EXPECT_EQ(parsed_version, LegacyVersionForEncapsulation()); EXPECT_EQ(destination_connection_id, server_connection_id); EXPECT_EQ(source_connection_id, EmptyQuicConnectionId()); - EXPECT_FALSE(retry_token_present); + EXPECT_FALSE(retry_token.has_value()); EXPECT_TRUE(detailed_error.empty()); // Processing the packet should create a new session. @@ -677,9 +691,10 @@ TEST_P(QuicDispatcherTestAllVersions, LegacyVersionEncapsulation) { TEST_P(QuicDispatcherTestAllVersions, ProcessPackets) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); - EXPECT_CALL(*dispatcher_, - CreateQuicSession(TestConnectionId(1), _, client_address, - Eq(ExpectedAlpn()), _, TestHostname())) + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, TestConnectionId(1), client_address, &mock_helper_, &mock_alarm_factory_, &crypto_config_, @@ -694,9 +709,10 @@ TEST_P(QuicDispatcherTestAllVersions, ProcessPackets) { ReceivedPacketInfoConnectionIdEquals(TestConnectionId(1)))); ProcessFirstFlight(client_address, TestConnectionId(1)); - EXPECT_CALL(*dispatcher_, - CreateQuicSession(TestConnectionId(2), _, client_address, - Eq(ExpectedAlpn()), _, TestHostname())) + EXPECT_CALL( + *dispatcher_, + CreateQuicSession(TestConnectionId(2), _, client_address, + Eq(ExpectedAlpn()), _, Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, TestConnectionId(2), client_address, &mock_helper_, &mock_alarm_factory_, &crypto_config_, @@ -894,7 +910,7 @@ TEST_P(QuicDispatcherTestAllVersions, TimeWaitListManager) { EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, connection_id, _, _, _)) .Times(1); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); ProcessPacket(client_address, connection_id, true, "data"); } @@ -910,7 +926,7 @@ TEST_P(QuicDispatcherTestAllVersions, NoVersionPacketToTimeWaitListManager) { EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, connection_id, _, _, _)) .Times(0); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) .Times(1); @@ -923,14 +939,16 @@ TEST_P(QuicDispatcherTestAllVersions, QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char short_packet[21] = {0x70, 0xa7, 0x02, 0x6b}; - QuicReceivedPacket packet(short_packet, 21, QuicTime::Zero()); - char valid_size_packet[23] = {0x70, 0xa7, 0x02, 0x6c}; - QuicReceivedPacket packet2(valid_size_packet, 23, QuicTime::Zero()); + uint8_t short_packet[21] = {0x70, 0xa7, 0x02, 0x6b}; + QuicReceivedPacket packet(reinterpret_cast(short_packet), 21, + QuicTime::Zero()); + uint8_t valid_size_packet[23] = {0x70, 0xa7, 0x02, 0x6c}; + QuicReceivedPacket packet2(reinterpret_cast(valid_size_packet), 23, + QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) .Times(0); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); // Verify small packet is silently dropped. EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) @@ -941,6 +959,120 @@ TEST_P(QuicDispatcherTestAllVersions, dispatcher_->ProcessPacket(server_address_, client_address, packet2); } +TEST_P(QuicDispatcherTestOneVersion, DropPacketWithInvalidFlags) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t all_zero_packet[1200] = {}; + QuicReceivedPacket packet(reinterpret_cast(all_zero_packet), + sizeof(all_zero_packet), QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(0); + dispatcher_->ProcessPacket(server_address_, client_address, packet); +} + +TEST_P(QuicDispatcherTestAllVersions, LimitResetsToSameClientAddress) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress client_address2(QuicIpAddress::Loopback4(), 2); + QuicSocketAddress client_address3(QuicIpAddress::Loopback6(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + + if (GetQuicRestartFlag(quic_use_recent_reset_addresses)) { + // Verify only one reset is sent to the address, although multiple packets + // are received. + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(1); + } else { + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(3); + } + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data2"); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data3"); + + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(2); + ProcessPacket(client_address2, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address3, connection_id, /*has_version_flag=*/false, + "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, + StopSendingResetOnTooManyRecentAddresses) { + SetQuicFlag(FLAGS_quic_max_recent_stateless_reset_addresses, 2); + const size_t kTestLifeTimeMs = 10; + SetQuicFlag(FLAGS_quic_recent_stateless_reset_addresses_lifetime_ms, + kTestLifeTimeMs); + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + QuicSocketAddress client_address2(QuicIpAddress::Loopback4(), 2); + QuicSocketAddress client_address3(QuicIpAddress::Loopback6(), 1); + QuicConnectionId connection_id = TestConnectionId(1); + + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(2); + EXPECT_FALSE(GetClearResetAddressesAlarm()->IsSet()); + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); + const QuicTime expected_deadline = + mock_helper_.GetClock()->Now() + + QuicTime::Delta::FromMilliseconds(kTestLifeTimeMs); + if (GetQuicRestartFlag(quic_use_recent_reset_addresses)) { + ASSERT_TRUE(GetClearResetAddressesAlarm()->IsSet()); + EXPECT_EQ(expected_deadline, GetClearResetAddressesAlarm()->deadline()); + } else { + EXPECT_FALSE(GetClearResetAddressesAlarm()->IsSet()); + } + // Received no version packet 2 after 5ms. + mock_helper_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + ProcessPacket(client_address2, connection_id, /*has_version_flag=*/false, + "data"); + if (GetQuicRestartFlag(quic_use_recent_reset_addresses)) { + ASSERT_TRUE(GetClearResetAddressesAlarm()->IsSet()); + // Verify deadline does not change. + EXPECT_EQ(expected_deadline, GetClearResetAddressesAlarm()->deadline()); + } else { + EXPECT_FALSE(GetClearResetAddressesAlarm()->IsSet()); + } + if (GetQuicRestartFlag(quic_use_recent_reset_addresses)) { + // Verify reset gets throttled since there are too many recent addresses. + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(0); + } else { + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(1); + } + ProcessPacket(client_address3, connection_id, /*has_version_flag=*/false, + "data"); + + mock_helper_.AdvanceTime(QuicTime::Delta::FromMilliseconds(5)); + if (GetQuicRestartFlag(quic_use_recent_reset_addresses)) { + GetClearResetAddressesAlarm()->Fire(); + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(2); + } else { + EXPECT_CALL(*time_wait_list_manager_, SendPublicReset(_, _, _, _, _, _)) + .Times(3); + } + ProcessPacket(client_address, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address2, connection_id, /*has_version_flag=*/false, + "data"); + ProcessPacket(client_address3, connection_id, /*has_version_flag=*/false, + "data"); +} + // Makes sure nine-byte connection IDs are replaced by 8-byte ones. TEST_P(QuicDispatcherTestAllVersions, LongConnectionIdLengthReplaced) { if (!version_.AllowsVariableLengthConnectionIds()) { @@ -1077,12 +1209,48 @@ TEST_P(QuicDispatcherTestAllVersions, ProcessPacketWithZeroPort) { .Times(0); EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) .Times(0); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, + "data"); +} + +TEST_P(QuicDispatcherTestAllVersions, ProcessPacketWithBlockedPort) { + CreateTimeWaitListManager(); + + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 17); + + // dispatcher_ should drop this packet. + EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(1), _, + client_address, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, "data"); } +TEST_P(QuicDispatcherTestAllVersions, ProcessPacketWithNonBlockedPort) { + CreateTimeWaitListManager(); + + // Port 443 must not be blocked because it might be useful for proxies to send + // proxied traffic with source port 443 as that allows building a full QUIC + // proxy using a single UDP socket. + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 443); + + // dispatcher_ should not drop this packet. + EXPECT_CALL(*dispatcher_, + CreateQuicSession(TestConnectionId(1), _, client_address, + Eq(ExpectedAlpn()), _, _)) + .WillOnce(Return(ByMove(CreateSession( + dispatcher_.get(), config_, TestConnectionId(1), client_address, + &mock_helper_, &mock_alarm_factory_, &crypto_config_, + QuicDispatcherPeer::GetCache(dispatcher_.get()), &session1_)))); + ProcessFirstFlight(client_address, TestConnectionId(1)); +} + TEST_P(QuicDispatcherTestAllVersions, DropPacketWithKnownVersionAndInvalidShortInitialConnectionId) { if (!version_.AllowsVariableLengthConnectionIds()) { @@ -1096,11 +1264,35 @@ TEST_P(QuicDispatcherTestAllVersions, EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) .Times(0); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); ProcessFirstFlight(client_address, EmptyQuicConnectionId()); } +TEST_P(QuicDispatcherTestAllVersions, + DropPacketWithKnownVersionAndInvalidInitialConnectionId) { + CreateTimeWaitListManager(); + + QuicSocketAddress server_address; + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + + // dispatcher_ should drop this packet with invalid connection ID. + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) + .Times(0); + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) + .Times(0); + absl::string_view cid_str = "123456789abcdefg123456789abcdefg"; + QuicConnectionId invalid_connection_id(cid_str.data(), cid_str.length()); + QuicReceivedPacket packet("packet", 6, QuicTime::Zero()); + ReceivedPacketInfo packet_info(server_address, client_address, packet); + packet_info.version_flag = true; + packet_info.version = version_; + packet_info.destination_connection_id = invalid_connection_id; + + ASSERT_TRUE(dispatcher_->MaybeDispatchPacket(packet_info)); +} + void QuicDispatcherTestBase:: TestVersionNegotiationForUnknownVersionInvalidShortInitialConnectionId( const QuicConnectionId& server_connection_id, @@ -1160,10 +1352,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionDraft28WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 0xFF, 0x00, 0x00, 28, /*destination connection ID length*/ 0x08}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1177,10 +1369,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionDraft27WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 0xFF, 0x00, 0x00, 27, /*destination connection ID length*/ 0x08}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1194,10 +1386,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionDraft25WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 0xFF, 0x00, 0x00, 25, /*destination connection ID length*/ 0x08}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1211,10 +1403,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionT050WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 'T', '0', '5', '0', /*destination connection ID length*/ 0x08}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1228,10 +1420,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionQ049WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 'Q', '0', '4', '9', /*destination connection ID length*/ 0x08}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1245,10 +1437,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionQ048WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 'Q', '0', '4', '8', /*connection ID length byte*/ 0x50}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1262,10 +1454,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionQ047WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 'Q', '0', '4', '7', /*connection ID length byte*/ 0x50}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1279,10 +1471,10 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionQ045WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { 0xC0, 'Q', '0', '4', '5', /*connection ID length byte*/ 0x50}; - QuicReceivedPacket received_packet(packet, ABSL_ARRAYSIZE(packet), - QuicTime::Zero()); + QuicReceivedPacket received_packet(reinterpret_cast(packet), + ABSL_ARRAYSIZE(packet), QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1296,10 +1488,11 @@ TEST_P(QuicDispatcherTestOneVersion, RejectDeprecatedVersionQ044WithVersionNegotiation) { QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); CreateTimeWaitListManager(); - char packet44[kMinPacketSizeForVersionNegotiation] = { + uint8_t packet44[kMinPacketSizeForVersionNegotiation] = { 0xFF, 'Q', '0', '4', '4', /*connection ID length byte*/ 0x50}; - QuicReceivedPacket received_packet44( - packet44, kMinPacketSizeForVersionNegotiation, QuicTime::Zero()); + QuicReceivedPacket received_packet44(reinterpret_cast(packet44), + kMinPacketSizeForVersionNegotiation, + QuicTime::Zero()); EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL( *time_wait_list_manager_, @@ -1310,7 +1503,25 @@ TEST_P(QuicDispatcherTestOneVersion, received_packet44); } -static_assert(quic::SupportedVersions().size() == 6u, +TEST_P(QuicDispatcherTestOneVersion, + RejectDeprecatedVersionT051WithVersionNegotiation) { + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 1); + CreateTimeWaitListManager(); + uint8_t packet[kMinPacketSizeForVersionNegotiation] = { + 0xFF, 'T', '0', '5', '1', /*destination connection ID length*/ 0x08}; + QuicReceivedPacket received_packet(reinterpret_cast(packet), + kMinPacketSizeForVersionNegotiation, + QuicTime::Zero()); + EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); + EXPECT_CALL( + *time_wait_list_manager_, + SendVersionNegotiationPacket(_, _, /*ietf_quic=*/true, + /*use_length_prefix=*/true, _, _, _, _)) + .Times(1); + dispatcher_->ProcessPacket(server_address_, client_address, received_packet); +} + +static_assert(quic::SupportedVersions().size() == 5u, "Please add new RejectDeprecatedVersion tests above this assert " "when deprecating versions"); @@ -1345,8 +1556,7 @@ class SavingWriter : public QuicPacketWriterWrapper { public: bool IsWriteBlocked() const override { return false; } - WriteResult WritePacket(const char* buffer, - size_t buf_len, + WriteResult WritePacket(const char* buffer, size_t buf_len, const QuicIpAddress& /*self_client_address*/, const QuicSocketAddress& /*peer_client_address*/, PerPacketOptions* /*options*/) override { @@ -1509,7 +1719,7 @@ TEST_P(QuicDispatcherTestAllVersions, DoNotProcessSmallPacket) { EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, SendPacket(_, _, _)).Times(0); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); ProcessPacket(client_address, TestConnectionId(1), /*has_version_flag=*/true, version_, SerializeCHLO(), /*full_padding=*/false, @@ -1523,7 +1733,7 @@ TEST_P(QuicDispatcherTestAllVersions, ProcessSmallCoalescedPacket) { EXPECT_CALL(*time_wait_list_manager_, SendPacket(_, _, _)).Times(0); // clang-format off - char coalesced_packet[1200] = { + uint8_t coalesced_packet[1200] = { // first coalesced packet // public flags (long header with packet type INITIAL and // 4-byte packet number) @@ -1560,7 +1770,8 @@ TEST_P(QuicDispatcherTestAllVersions, ProcessSmallCoalescedPacket) { 0x12, 0x34, 0x56, 0x79, }; // clang-format on - QuicReceivedPacket packet(coalesced_packet, 1200, QuicTime::Zero()); + QuicReceivedPacket packet(reinterpret_cast(coalesced_packet), 1200, + QuicTime::Zero()); dispatcher_->ProcessPacket(server_address_, client_address, packet); } @@ -1661,7 +1872,7 @@ TEST_P(QuicDispatcherTestStrayPacketConnectionId, EXPECT_CALL(*dispatcher_, CreateQuicSession(_, _, _, _, _, _)).Times(0); EXPECT_CALL(*time_wait_list_manager_, ProcessPacket(_, _, _, _, _, _)) .Times(0); - EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*time_wait_list_manager_, AddConnectionIdToTimeWait(_, _)) .Times(0); ProcessPacket(client_address, connection_id, true, "data", @@ -1675,8 +1886,7 @@ class BlockingWriter : public QuicPacketWriterWrapper { bool IsWriteBlocked() const override { return write_blocked_; } void SetWritable() override { write_blocked_ = false; } - WriteResult WritePacket(const char* /*buffer*/, - size_t /*buf_len*/, + WriteResult WritePacket(const char* /*buffer*/, size_t /*buf_len*/, const QuicIpAddress& /*self_client_address*/, const QuicSocketAddress& /*peer_client_address*/, PerPacketOptions* /*options*/) override { @@ -1976,10 +2186,6 @@ class QuicDispatcherSupportMultipleConnectionIdPerConnectionTest public: QuicDispatcherSupportMultipleConnectionIdPerConnectionTest() : QuicDispatcherTestBase(crypto_test_utils::ProofSourceForTesting()) { - SetQuicRestartFlag(quic_use_reference_counted_sesssion_map, true); - SetQuicRestartFlag(quic_time_wait_list_support_multiple_cid_v2, true); - SetQuicRestartFlag(quic_dispatcher_support_multiple_cid_per_connection_v2, - true); dispatcher_ = std::make_unique>( &config_, &crypto_config_, &version_manager_, mock_helper_.GetRandomGenerator()); @@ -2183,8 +2389,7 @@ class BufferedPacketStoreTest : public QuicDispatcherTestBase { } void ProcessUndecryptableEarlyPacket( - const ParsedQuicVersion& version, - const QuicSocketAddress& peer_address, + const ParsedQuicVersion& version, const QuicSocketAddress& peer_address, const QuicConnectionId& server_connection_id) { QuicDispatcherTestBase::ProcessUndecryptableEarlyPacket( version, peer_address, server_connection_id); @@ -2207,8 +2412,7 @@ class BufferedPacketStoreTest : public QuicDispatcherTestBase { QuicSocketAddress client_addr_; }; -INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests, - BufferedPacketStoreTest, +INSTANTIATE_TEST_SUITE_P(BufferedPacketStoreTests, BufferedPacketStoreTest, ::testing::ValuesIn(CurrentSupportedVersions()), ::testing::PrintToStringParamName()); @@ -2228,7 +2432,7 @@ TEST_P(BufferedPacketStoreTest, ProcessNonChloPacketBeforeChlo) { // buffered should be delivered to the session. EXPECT_CALL(*dispatcher_, CreateQuicSession(conn_id, _, client_addr_, Eq(ExpectedAlpn()), _, - TestHostname())) + Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, conn_id, client_addr_, &mock_helper_, &mock_alarm_factory_, &crypto_config_, @@ -2289,7 +2493,7 @@ TEST_P(BufferedPacketStoreTest, // A bunch of non-CHLO should be buffered upon arrival. size_t kNumConnections = kMaxConnectionsWithoutCHLO + 1; for (size_t i = 1; i <= kNumConnections; ++i) { - QuicSocketAddress client_address(QuicIpAddress::Loopback4(), i); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 20000 + i); QuicConnectionId conn_id = TestConnectionId(i); EXPECT_CALL(*dispatcher_, ShouldCreateOrBufferPacketForConnection( @@ -2307,7 +2511,7 @@ TEST_P(BufferedPacketStoreTest, kNumConnections); // Process CHLOs to create session for these connections. for (size_t i = 1; i <= kNumConnections; ++i) { - QuicSocketAddress client_address(QuicIpAddress::Loopback4(), i); + QuicSocketAddress client_address(QuicIpAddress::Loopback4(), 20000 + i); QuicConnectionId conn_id = TestConnectionId(i); if (i == kNumConnections) { EXPECT_CALL(*dispatcher_, @@ -2432,7 +2636,8 @@ TEST_P(BufferedPacketStoreTest, ProcessCHLOsUptoLimitAndBufferTheRest) { if (conn_id <= kMaxNumSessionsToCreate) { EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, - Eq(ExpectedAlpn()), _, TestHostname())) + Eq(ExpectedAlpn()), _, + Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, TestConnectionId(conn_id), client_addr_, &mock_helper_, &mock_alarm_factory_, @@ -2467,7 +2672,8 @@ TEST_P(BufferedPacketStoreTest, ProcessCHLOsUptoLimitAndBufferTheRest) { ++conn_id) { EXPECT_CALL(*dispatcher_, CreateQuicSession(TestConnectionId(conn_id), _, client_addr_, - Eq(ExpectedAlpn()), _, TestHostname())) + Eq(ExpectedAlpn()), _, + Eq(ParsedClientHelloForTest()))) .WillOnce(Return(ByMove(CreateSession( dispatcher_.get(), config_, TestConnectionId(conn_id), client_addr_, &mock_helper_, &mock_alarm_factory_, &crypto_config_, diff --git a/gquiche/quic/core/quic_epoll_alarm_factory_test.cc b/gquiche/quic/core/quic_epoll_alarm_factory_test.cc index f33a1e39..bc3e725a 100644 --- a/gquiche/quic/core/quic_epoll_alarm_factory_test.cc +++ b/gquiche/quic/core/quic_epoll_alarm_factory_test.cc @@ -6,13 +6,13 @@ #include "gquiche/quic/platform/api/quic_epoll_test_tools.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "platform/quic_epoll_clock.h" +#include "platform/quic_platform_impl/quic_epoll_clock.h" namespace quic { namespace test { namespace { -class TestDelegate : public QuicAlarm::Delegate { +class TestDelegate : public QuicAlarm::DelegateWithoutContext { public: TestDelegate() : fired_(false) {} diff --git a/gquiche/quic/core/quic_error_codes.cc b/gquiche/quic/core/quic_error_codes.cc index 113022f0..561a7df8 100644 --- a/gquiche/quic/core/quic_error_codes.cc +++ b/gquiche/quic/core/quic_error_codes.cc @@ -238,6 +238,8 @@ const char* QuicErrorCodeToString(QuicErrorCode error) { RETURN_STRING_LITERAL(QUIC_HTTP_GOAWAY_ID_LARGER_THAN_PREVIOUS); RETURN_STRING_LITERAL(QUIC_HTTP_RECEIVE_SPDY_SETTING); RETURN_STRING_LITERAL(QUIC_HTTP_RECEIVE_SPDY_FRAME); + RETURN_STRING_LITERAL(QUIC_HTTP_RECEIVE_SERVER_PUSH); + RETURN_STRING_LITERAL(QUIC_HTTP_INVALID_SETTING_VALUE); RETURN_STRING_LITERAL(QUIC_HPACK_INDEX_VARINT_ERROR); RETURN_STRING_LITERAL(QUIC_HPACK_NAME_LENGTH_VARINT_ERROR); RETURN_STRING_LITERAL(QUIC_HPACK_VALUE_LENGTH_VARINT_ERROR); @@ -273,6 +275,10 @@ const char* QuicErrorCodeToString(QuicErrorCode error) { RETURN_STRING_LITERAL(QUIC_TLS_INTERNAL_ERROR); RETURN_STRING_LITERAL(QUIC_TLS_UNRECOGNIZED_NAME); RETURN_STRING_LITERAL(QUIC_TLS_CERTIFICATE_REQUIRED); + RETURN_STRING_LITERAL(QUIC_INVALID_CHARACTER_IN_FIELD_VALUE); + RETURN_STRING_LITERAL(QUIC_TLS_UNEXPECTED_KEYING_MATERIAL_EXPORT_LABEL); + RETURN_STRING_LITERAL(QUIC_TLS_KEYING_MATERIAL_EXPORTS_MISMATCH); + RETURN_STRING_LITERAL(QUIC_TLS_KEYING_MATERIAL_EXPORT_NOT_AVAILABLE); RETURN_STRING_LITERAL(QUIC_LAST_ERROR); // Intentionally have no default case, so we'll break the build @@ -285,9 +291,6 @@ const char* QuicErrorCodeToString(QuicErrorCode error) { } std::string QuicIetfTransportErrorCodeString(QuicIetfTransportErrorCodes c) { - if (static_cast(c) >= 0xff00u) { - return absl::StrCat("Private(", static_cast(c), ")"); - } if (c >= CRYPTO_ERROR_FIRST && c <= CRYPTO_ERROR_LAST) { const int tls_error = static_cast(c - CRYPTO_ERROR_FIRST); const char* tls_error_description = SSL_alert_desc_string_long(tls_error); @@ -678,6 +681,9 @@ QuicErrorCodeToIetfMapping QuicErrorCodeToTransportErrorCode( case QUIC_HTTP_STREAM_LIMIT_TOO_LOW: return {false, static_cast( QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR)}; + case QUIC_HTTP_RECEIVE_SERVER_PUSH: + return {false, static_cast( + QuicHttp3ErrorCode::GENERAL_PROTOCOL_ERROR)}; case QUIC_HTTP_ZERO_RTT_RESUMPTION_SETTINGS_MISMATCH: return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; case QUIC_HTTP_ZERO_RTT_REJECTION_SETTINGS_MISMATCH: @@ -688,6 +694,8 @@ QuicErrorCodeToIetfMapping QuicErrorCodeToTransportErrorCode( return {false, static_cast(QuicHttp3ErrorCode::ID_ERROR)}; case QUIC_HTTP_RECEIVE_SPDY_SETTING: return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; + case QUIC_HTTP_INVALID_SETTING_VALUE: + return {false, static_cast(QuicHttp3ErrorCode::SETTINGS_ERROR)}; case QUIC_HTTP_RECEIVE_SPDY_FRAME: return {false, static_cast(QuicHttp3ErrorCode::FRAME_UNEXPECTED)}; @@ -768,6 +776,14 @@ QuicErrorCodeToIetfMapping QuicErrorCodeToTransportErrorCode( return {true, static_cast(CONNECTION_ID_LIMIT_ERROR)}; case QUIC_TOO_MANY_CONNECTION_ID_WAITING_TO_RETIRE: return {true, static_cast(INTERNAL_ERROR)}; + case QUIC_INVALID_CHARACTER_IN_FIELD_VALUE: + return {false, static_cast(QuicHttp3ErrorCode::MESSAGE_ERROR)}; + case QUIC_TLS_UNEXPECTED_KEYING_MATERIAL_EXPORT_LABEL: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TLS_KEYING_MATERIAL_EXPORTS_MISMATCH: + return {true, static_cast(PROTOCOL_VIOLATION)}; + case QUIC_TLS_KEYING_MATERIAL_EXPORT_NOT_AVAILABLE: + return {true, static_cast(PROTOCOL_VIOLATION)}; case QUIC_LAST_ERROR: return {false, static_cast(QUIC_LAST_ERROR)}; } @@ -935,6 +951,19 @@ QuicRstStreamErrorCode IetfResetStreamErrorCodeToRstStreamErrorCode( return QUIC_STREAM_UNKNOWN_APPLICATION_ERROR_CODE; } +// static +QuicResetStreamError QuicResetStreamError::FromInternal( + QuicRstStreamErrorCode code) { + return QuicResetStreamError( + code, RstStreamErrorCodeToIetfResetStreamErrorCode(code)); +} + +// static +QuicResetStreamError QuicResetStreamError::FromIetf(uint64_t code) { + return QuicResetStreamError( + IetfResetStreamErrorCodeToRstStreamErrorCode(code), code); +} + #undef RETURN_STRING_LITERAL // undef for jumbo builds } // namespace quic diff --git a/gquiche/quic/core/quic_error_codes.h b/gquiche/quic/core/quic_error_codes.h index e7eb83f1..c71f8290 100644 --- a/gquiche/quic/core/quic_error_codes.h +++ b/gquiche/quic/core/quic_error_codes.h @@ -459,8 +459,12 @@ enum QuicErrorCode { // Received multiple close offset. QUIC_STREAM_MULTIPLE_OFFSET = 130, - // Internal error codes for HTTP/3 errors. + // HTTP/3 errors. + + // Frame payload larger than what HttpDecoder is willing to buffer. QUIC_HTTP_FRAME_TOO_LARGE = 131, + // Malformed HTTP/3 frame, or PUSH_PROMISE or CANCEL_PUSH received (which is + // an error because MAX_PUSH_ID is never sent). QUIC_HTTP_FRAME_ERROR = 132, // A frame that is never allowed on a request stream is received. QUIC_HTTP_FRAME_UNEXPECTED_ON_SPDY_STREAM = 133, @@ -506,6 +510,11 @@ enum QuicErrorCode { QUIC_HTTP_RECEIVE_SPDY_SETTING = 169, // HTTP/3 session received an HTTP/2 only frame. QUIC_HTTP_RECEIVE_SPDY_FRAME = 171, + // HTTP/3 session received SERVER_PUSH stream, which is an error because + // PUSH_PROMISE is not accepted. + QUIC_HTTP_RECEIVE_SERVER_PUSH = 205, + // HTTP/3 session received invalid SETTING value. + QUIC_HTTP_INVALID_SETTING_VALUE = 207, // HPACK header block decoding errors. // Index varint beyond implementation limit. @@ -592,8 +601,16 @@ enum QuicErrorCode { QUIC_TLS_UNRECOGNIZED_NAME = 201, QUIC_TLS_CERTIFICATE_REQUIRED = 202, + // An HTTP field value containing an invalid character has been received. + QUIC_INVALID_CHARACTER_IN_FIELD_VALUE = 206, + + // Error code related to the usage of TLS keying material export. + QUIC_TLS_UNEXPECTED_KEYING_MATERIAL_EXPORT_LABEL = 208, + QUIC_TLS_KEYING_MATERIAL_EXPORTS_MISMATCH = 209, + QUIC_TLS_KEYING_MATERIAL_EXPORT_NOT_AVAILABLE = 210, + // No error. Used as bound while iterating. - QUIC_LAST_ERROR = 205, + QUIC_LAST_ERROR = 211, }; // QuicErrorCodes is encoded as four octets on-the-wire when doing Google QUIC, // or a varint62 when doing IETF QUIC. Ensure that its value does not exceed @@ -602,6 +619,45 @@ static_assert(static_cast(QUIC_LAST_ERROR) <= static_cast(std::numeric_limits::max()), "QuicErrorCode exceeds four octets"); +// Represents a reason for resetting a stream in both gQUIC and IETF error code +// space. Both error codes have to be present. +class QUIC_EXPORT_PRIVATE QuicResetStreamError { + public: + // Constructs a QuicResetStreamError from QuicRstStreamErrorCode; the IETF + // error code is inferred. + static QuicResetStreamError FromInternal(QuicRstStreamErrorCode code); + // Constructs a QuicResetStreamError from an IETF error code; the internal + // error code is inferred. + static QuicResetStreamError FromIetf(uint64_t code); + // Constructs a QuicResetStreamError with no error. + static QuicResetStreamError NoError() { + return FromInternal(QUIC_STREAM_NO_ERROR); + } + + QuicResetStreamError(QuicRstStreamErrorCode internal_code, + uint64_t ietf_application_code) + : internal_code_(internal_code), + ietf_application_code_(ietf_application_code) {} + + QuicRstStreamErrorCode internal_code() const { return internal_code_; } + uint64_t ietf_application_code() const { return ietf_application_code_; } + + bool operator==(const QuicResetStreamError& other) const { + return internal_code() == other.internal_code() && + ietf_application_code() == other.ietf_application_code(); + } + + // Returns true if the object holds no error. + bool ok() const { return internal_code() == QUIC_STREAM_NO_ERROR; } + + private: + // Error code used in gQUIC. Even when IETF QUIC is in use, this needs to be + // populated as we use those internally. + QuicRstStreamErrorCode internal_code_; + // Application error code used in IETF QUIC. + uint64_t ietf_application_code_; +}; + // Convert TLS alert code to QuicErrorCode. QUIC_EXPORT_PRIVATE QuicErrorCode TlsAlertToQuicErrorCode(uint8_t desc); @@ -638,8 +694,7 @@ QUIC_EXPORT_PRIVATE std::string QuicIetfTransportErrorCodeString( QuicIetfTransportErrorCodes c); QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const QuicIetfTransportErrorCodes& c); + std::ostream& os, const QuicIetfTransportErrorCodes& c); // A transport error code (if is_transport_close is true) or application error // code (if is_transport_close is false) to be used in CONNECTION_CLOSE frames. @@ -671,6 +726,7 @@ enum class QuicHttp3ErrorCode { REQUEST_REJECTED = 0x10B, REQUEST_CANCELLED = 0x10C, REQUEST_INCOMPLETE = 0x10D, + MESSAGE_ERROR = 0x10E, CONNECT_ERROR = 0x10F, VERSION_FALLBACK = 0x110, }; diff --git a/gquiche/quic/core/quic_error_codes_test.cc b/gquiche/quic/core/quic_error_codes_test.cc index 5ff3757f..581565b1 100644 --- a/gquiche/quic/core/quic_error_codes_test.cc +++ b/gquiche/quic/core/quic_error_codes_test.cc @@ -20,10 +20,6 @@ TEST_F(QuicErrorCodesTest, QuicErrorCodeToString) { } TEST_F(QuicErrorCodesTest, QuicIetfTransportErrorCodeString) { - EXPECT_EQ("Private(65280)", - QuicIetfTransportErrorCodeString( - static_cast(0xff00u))); - EXPECT_EQ("CRYPTO_ERROR(missing extension)", QuicIetfTransportErrorCodeString( static_cast( diff --git a/gquiche/quic/core/quic_flags_list.h b/gquiche/quic/core/quic_flags_list.h index 287f1767..b7b7f155 100644 --- a/gquiche/quic/core/quic_flags_list.h +++ b/gquiche/quic/core/quic_flags_list.h @@ -4,81 +4,145 @@ // This file is autogenerated by the QUICHE Copybara export script. +#ifdef QUIC_FLAG + +QUIC_FLAG(FLAGS_quic_restart_flag_quic_offload_pacing_to_usps2, false) +// A testonly reloadable flag that will always default to false. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_false, false) +// A testonly reloadable flag that will always default to true. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_true, true) +// A testonly restart flag that will always default to false. +QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_false, false) +// A testonly restart flag that will always default to true. +QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_true, true) +// Donot check amplification limit if there is available pending_timer_transmission_count. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_check_amplification_limit_with_pending_timer_credit, false) +// If bytes in flight has dipped below 1.25*MaxBW in the last round, do not exit PROBE_UP due to excess queue buildup. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_no_probe_up_exit_if_no_queue, true) +// If true, 1) NEW_TOKENs sent from a IETF QUIC session will include the cached network parameters proto, 2) A min_rtt received from a validated token will be used to set the initial rtt, 3) Enable bandwidth resumption for IETF QUIC when connection options BWRE or BWMX exists. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_add_cached_network_parameters_to_address_token2, false) +// If true, QUIC will default enable MTU discovery at server, with a target of 1450 bytes. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_mtu_discovery_at_server, false) +// If true, QUIC won\'t honor the connection option TLPR +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_deprecate_tlpr, true) +// If true, QuicGsoBatchWriter will support release time if it is available and the process has the permission to do so. +QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false) +// If true, TlsServerHandshaker will be able to 1) request client cert, and 2) verify the client cert in the virtual method TlsServerHandshaker::VerifyCertChain. +QUIC_FLAG(FLAGS_quic_restart_flag_quic_tls_server_support_client_cert, true) +// If true, abort async QPACK header decompression in QuicSpdyStream::Reset() and in QuicSpdyStream::OnStreamReset(). QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_abort_qpack_on_stream_reset, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_accept_empty_stream_frame_with_no_fin, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_ack_delay_alarm_granularity, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_add_packet_flusher_on_async_op_done, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_add_stream_info_to_idle_close_detail, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_allocate_stream_sequencer_buffer_blocks_on_demand, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_allow_client_enabled_bbr_v2, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_alps_include_scheme_in_origin, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_and_tls_allow_sni_without_dots, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_batch_writer_fix_write_blocked, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_avoid_too_low_probe_bw_cwnd, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_fix_bw_lo_mode, false) +// If true, accept empty crypto frame. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_accept_empty_crypto_frame, false) +// If true, ack frequency frame can be sent from server to client. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_can_send_ack_frequency, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_close_connection_with_too_many_outstanding_packets, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_connection_support_multiple_cids_v4, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_bursts, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_cwnd_and_pacing_gains, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_count_bytes_on_alternative_path_seperately, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_crypto_postpone_cert_validate_for_server, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_enable_5rto_blackhole_detection2, true) +// If true, allow client to enable BBRv2 on server via connection option \'B2ON\'. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_allow_client_enabled_bbr_v2, false) +// If true, change QuicCryptoServerStream::FinishProcessingHandshakeMessageAfterProcessClientHello to noop if connection is disconnected. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_crypto_noop_if_disconnected_after_process_chlo, true) +// If true, clear undecryptable packets on handshake complete. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_clear_undecryptable_packets_on_handshake_complete, true) +// If true, close read side but not write side in QuicSpdyStream::OnStreamReset(). +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_on_stream_reset, true) +// If true, default on PTO which unifies TLP + RTO loss recovery. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_on_pto, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_to_bbr, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_to_bbr_v2, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_delay_initial_ack, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_deprecate_incoming_connection_ids, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_server_blackhole_detection, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_draft_29, false) +// If true, default-enable 5RTO blachole detection. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_enable_5rto_blackhole_detection2, true) +// If true, delay block allocation in QuicStreamSequencerBuffer until there is actually new data available. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_delay_sequencer_buffer_allocation_until_new_data, true) +// If true, disable QUIC version Q043. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_q043, false) +// If true, disable QUIC version Q046. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_q046, false) +// If true, disable QUIC version Q050. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_q050, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_t051, false) +// If true, disable QUIC version h3 (RFCv1). +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_rfcv1, false) +// If true, disable QUIC version h3-29. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_version_draft_29, false) +// If true, disable blackhole detection on server side. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_disable_server_blackhole_detection, false) +// If true, discard INITIAL packet if the key has been dropped. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_discard_initial_packet_with_key_dropped, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_do_not_synthesize_source_cid_for_short_header, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_pto_half_rtt_data, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_reset_ideal_next_packet_send_time, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_write_mid_packet_processing, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_drop_unsent_path_response, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_alps_client, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_alps_server, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_mtu_discovery_at_server, false) +// If true, do not bundle 2nd ACK with connection close if there is an ACK queued. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_single_ack_in_packet2, false) +// If true, do not call ProofSourceHandle::SelectCertificate if QUIC connection has disconnected. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_no_select_cert_if_disconnected, true) +// If true, do not count bytes sent/received on the alternative path into the bytes sent/received on the default path. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_count_bytes_on_alternative_path_seperately, true) +// If true, do not re-arm PTO while sending application data during handshake. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_rearm_pto_on_application_data_during_handshake, true) +// If true, do not use the gQUIC common certificate set for certificate compression. +QUIC_FLAG(FLAGS_quic_restart_flag_quic_no_common_cert_set, true) +// If true, drop unsent PATH_RESPONSEs and rely on peer\'s retry. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_drop_unsent_path_response, true) +// If true, enable server retransmittable on wire PING. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_server_on_wire_ping, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_token_based_address_validation, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_enable_version_rfcv1, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_encrypted_control_frames, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_encrypted_goaway, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_dispatcher_sent_error_code, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_key_update_on_first_packet, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_on_stream_reset, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_stateless_reset, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_willing_and_able_to_write2, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_group_path_response_and_challenge_sending_closer, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_h3_datagram, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_pass_path_response_to_validator, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_preempt_stream_data_with_handshake_packet, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_reject_unexpected_ietf_frame_types, true) +// If true, flush pending frames as well as pending padding bytes on connection migration. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_flush_pending_frames_and_padding_bytes_on_migration, true) +// If true, ietf connection migration is no longer conditioned on connection option RVCM. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_remove_connection_migration_connection_option, false) +// If true, ignore peer_max_ack_delay during handshake. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_ignore_peer_max_ack_delay_during_handshake, true) +// If true, include stream information in idle timeout connection close detail. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_add_stream_info_to_idle_close_detail, true) +// If true, pass the received PATH_RESPONSE payload to path validator to move forward the path validation. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_pass_path_response_to_validator, true) +// If true, quic server will send ENABLE_CONNECT_PROTOCOL setting and and endpoint will validate required request/response headers and extended CONNECT mechanism and update code counts of valid/invalid headers. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_verify_request_headers_2, true) +// If true, record addresses that server has sent reset to recently, and do not send reset if the address lives in the set. +QUIC_FLAG(FLAGS_quic_restart_flag_quic_use_recent_reset_addresses, true) +// If true, reject or send error response code upon receiving invalid request or response headers. This flag depends on --gfe2_reloadable_flag_quic_verify_request_headers_2. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_act_upon_invalid_header, false) +// If true, require handshake confirmation for QUIC connections, functionally disabling 0-rtt handshakes. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_require_handshake_confirmation, false) +// If true, reset per packet state before processing undecryptable packets. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_reset_per_packet_state_for_undecryptable_packets, true) +// If true, send PATH_RESPONSE upon receiving PATH_CHALLENGE regardless of perspective. --gfe2_reloadable_flag_quic_start_peer_migration_earlier has to be true before turn on this flag. QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_path_response2, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_timestamps, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_send_tls_crypto_error_code, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_server_reverse_validate_new_path3, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_single_ack_in_packet2, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_start_peer_migration_earlier, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_false, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_testonly_default_true, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_normalized_sni_for_cert_selectioon, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_unified_iw_options, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_connection_id_on_default_path, false) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_encryption_level_context, true) -QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_use_write_or_buffer_data_at_level, false) -QUIC_FLAG(FLAGS_quic_restart_flag_dont_fetch_quic_private_keys_from_leto, false) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_dispatcher_support_multiple_cid_per_connection_v2, true) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_offload_pacing_to_usps2, false) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_session_tickets_always_enabled, true) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_support_release_time_for_gso, false) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_false, false) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_testonly_default_true, true) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_time_wait_list_support_multiple_cid_v2, true) -QUIC_FLAG(FLAGS_quic_restart_flag_quic_use_reference_counted_sesssion_map, true) +// If true, set burst token to 2 in cwnd bootstrapping experiment. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_bursts, false) +// If true, stop resetting ideal_next_packet_send_time_ in pacing sender. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_donot_reset_ideal_next_packet_send_time, false) +// If true, suppress crypto data write in mid of packet processing. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_suppress_write_mid_packet_processing, true) +// If true, use BBRv2 as the default congestion controller. Takes precedence over --quic_default_to_bbr. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_to_bbr_v2, false) +// If true, use max(max_bw, send_rate) as the estimated bandwidth in QUIC\'s MaxAckHeightTracker. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr_use_send_rate_in_max_ack_height_tracker, true) +// If true, use new connection ID in connection migration. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_connection_migration_use_new_cid_v2, true) +// If true, uses conservative cwnd gain and pacing gain when cwnd gets bootstrapped. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_conservative_cwnd_and_pacing_gains, false) +// If true, validate that peer owns the new address once the server detects peer migration or is probed from that address, and also apply anti-amplification limit while sending to that address. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_server_reverse_validate_new_path3, true) +// If true, when client attempts TLS resumption, use token in session_cache_ instead of cached_states_ in QuicCryptoClientConfig. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_tls_use_token_in_session_cache, true) +// When receiving STOP_SENDING, send a RESET_STREAM with a matching error code. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_match_ietf_reset_code, true) +// When the flag is true, exit STARTUP after the same number of loss events as PROBE_UP. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_startup_probe_up_loss_events, true) +// When true, QUIC server will ignore received key_update_not_yet_supported transport parameter. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_ignore_key_update_not_yet_supported, true) +// When true, QUIC server will ignore received user agent transport parameter and rely on getting that information from HTTP headers. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_ignore_user_agent_transport_parameter, true) +// When true, QUIC will both send and validate the version_information transport parameter. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_version_information, false) +// When true, defaults to BBR congestion control instead of Cubic. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_default_to_bbr, false) +// When true, prevents QUIC\'s PacingSender from generating bursts when the congestion controller is CWND limited and not pacing limited. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_fix_pacing_sender_bursts, false) +// When true, set the initial congestion control window from connection options in QuicSentPacketManager rather than TcpCubicSenderBytes. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_unified_iw_options, true) +// When true, the B203 connection option causes the Bbr2Sender to ignore inflight_hi during PROBE_UP and increase it when the bytes delivered without loss are higher. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_ignore_inflight_hi_in_probe_up, true) +// When true, the B205 connection option enables extra acked in STARTUP, and B204 adds new logic to decrease it whenever max bandwidth increases. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_startup_extra_acked, true) +// When true, the B207 connection option causes BBR2 to exit STARTUP if a persistent queue of 2*BDP has existed for the entire round. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_exit_startup_on_persistent_queue2, true) +// When true, the BBQ0 connection option causes QUIC BBR2 to add bytes_acked to probe_up_acked if the connection hasn\'t been app-limited since inflight_hi was utilized. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_add_bytes_acked_after_inflight_hi_limited, true) +// When true, the BBR4 copt sets the extra_acked window to 20 RTTs and BBR5 sets it to 40 RTTs. +QUIC_FLAG(FLAGS_quic_reloadable_flag_quic_bbr2_extra_acked_window, true) + +#endif + diff --git a/gquiche/quic/core/quic_framer.cc b/gquiche/quic/core/quic_framer.cc index e3d9454b..5d2b0fee 100644 --- a/gquiche/quic/core/quic_framer.cc +++ b/gquiche/quic/core/quic_framer.cc @@ -47,9 +47,8 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_stack_trace.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -399,8 +398,7 @@ std::string GenerateErrorString(std::string initial_error_string, } // namespace QuicFramer::QuicFramer(const ParsedQuicVersionVector& supported_versions, - QuicTime creation_time, - Perspective perspective, + QuicTime creation_time, Perspective perspective, uint8_t expected_server_connection_id_length) : visitor_(nullptr), error_(QUIC_NO_ERROR), @@ -414,6 +412,7 @@ QuicFramer::QuicFramer(const ParsedQuicVersionVector& supported_versions, perspective_(perspective), validate_flags_(true), process_timestamps_(false), + receive_timestamps_exponent_(0), creation_time_(creation_time), last_timestamp_(QuicTime::Delta::Zero()), support_key_update_for_connection_(false), @@ -430,15 +429,12 @@ QuicFramer::QuicFramer(const ParsedQuicVersionVector& supported_versions, last_written_packet_number_length_(0), peer_ack_delay_exponent_(kDefaultAckDelayExponent), local_ack_delay_exponent_(kDefaultAckDelayExponent), - current_received_frame_type_(0) { + current_received_frame_type_(0), + previously_received_frame_type_(0) { QUICHE_DCHECK(!supported_versions.empty()); version_ = supported_versions_[0]; QUICHE_DCHECK(version_.IsKnown()) << ParsedQuicVersionVectorToString(supported_versions_); - if (do_not_synthesize_source_cid_for_short_header_) { - QUIC_RELOADABLE_FLAG_COUNT_N( - quic_do_not_synthesize_source_cid_for_short_header, 1, 3); - } } QuicFramer::~QuicFramer() {} @@ -1305,81 +1301,45 @@ size_t QuicFramer::GetMinStatelessResetPacketLength() { // static std::unique_ptr QuicFramer::BuildIetfStatelessResetPacket( - QuicConnectionId /*connection_id*/, - size_t received_packet_length, + QuicConnectionId /*connection_id*/, size_t received_packet_length, StatelessResetToken stateless_reset_token) { QUIC_DVLOG(1) << "Building IETF stateless reset packet."; - if (GetQuicReloadableFlag(quic_fix_stateless_reset)) { - if (received_packet_length <= GetMinStatelessResetPacketLength()) { - QUIC_BUG(362045737_1) - << "Tried to build stateless reset packet with received packet " - "length " - << received_packet_length; - return nullptr; - } - // To ensure stateless reset is indistinguishable from a valid packet, - // include the max connection ID length. - size_t len = std::min(received_packet_length - 1, - GetMinStatelessResetPacketLength() + 1 + - kQuicMaxConnectionIdWithLengthPrefixLength); - std::unique_ptr buffer(new char[len]); - QuicDataWriter writer(len, buffer.get()); - // Append random bytes. This randomness only exists to prevent middleboxes - // from comparing the entire packet to a known value. Therefore it has no - // cryptographic use, and does not need a secure cryptographic pseudo-random - // number generator. It's therefore safe to use WriteInsecureRandomBytes. - if (!writer.WriteInsecureRandomBytes(QuicRandom::GetInstance(), - len - kStatelessResetTokenLength)) { - QUIC_BUG(362045737_2) << "Failed to append random bytes of length: " - << len - kStatelessResetTokenLength; - return nullptr; - } - // Change first 2 fixed bits to 01. - buffer[0] &= ~FLAGS_LONG_HEADER; - buffer[0] |= FLAGS_FIXED_BIT; - - // Append stateless reset token. - if (!writer.WriteBytes(&stateless_reset_token, - sizeof(stateless_reset_token))) { - QUIC_BUG(362045737_3) << "Failed to write stateless reset token"; - return nullptr; - } - QUIC_RELOADABLE_FLAG_COUNT(quic_fix_stateless_reset); - return std::make_unique(buffer.release(), len, - /*owns_buffer=*/true); + if (received_packet_length <= GetMinStatelessResetPacketLength()) { + QUICHE_DLOG(ERROR) + << "Tried to build stateless reset packet with received packet " + "length " + << received_packet_length; + return nullptr; } - - size_t len = kPacketHeaderTypeSize + kMinRandomBytesLengthInStatelessReset + - sizeof(stateless_reset_token); + // To ensure stateless reset is indistinguishable from a valid packet, + // include the max connection ID length. + size_t len = std::min(received_packet_length - 1, + GetMinStatelessResetPacketLength() + 1 + + kQuicMaxConnectionIdWithLengthPrefixLength); std::unique_ptr buffer(new char[len]); QuicDataWriter writer(len, buffer.get()); - - uint8_t type = 0; - type |= FLAGS_FIXED_BIT; - type |= FLAGS_SHORT_HEADER_RESERVED_1; - type |= FLAGS_SHORT_HEADER_RESERVED_2; - type |= PacketNumberLengthToOnWireValue(PACKET_1BYTE_PACKET_NUMBER); - - // Append type byte. - if (!writer.WriteUInt8(type)) { - return nullptr; - } - // Append random bytes. This randomness only exists to prevent middleboxes // from comparing the entire packet to a known value. Therefore it has no // cryptographic use, and does not need a secure cryptographic pseudo-random // number generator. It's therefore safe to use WriteInsecureRandomBytes. if (!writer.WriteInsecureRandomBytes(QuicRandom::GetInstance(), - kMinRandomBytesLengthInStatelessReset)) { + len - kStatelessResetTokenLength)) { + QUIC_BUG(362045737_2) << "Failed to append random bytes of length: " + << len - kStatelessResetTokenLength; return nullptr; } + // Change first 2 fixed bits to 01. + buffer[0] &= ~FLAGS_LONG_HEADER; + buffer[0] |= FLAGS_FIXED_BIT; // Append stateless reset token. if (!writer.WriteBytes(&stateless_reset_token, sizeof(stateless_reset_token))) { + QUIC_BUG(362045737_3) << "Failed to write stateless reset token"; return nullptr; } - return std::make_unique(buffer.release(), len, true); + return std::make_unique(buffer.release(), len, + /*owns_buffer=*/true); } // static @@ -1507,6 +1467,7 @@ QuicFramer::BuildIetfVersionNegotiationPacket( } bool QuicFramer::ProcessPacket(const QuicEncryptedPacket& packet) { + QUICHE_DCHECK(!is_processing_packet_) << ENDPOINT << "Nested ProcessPacket"; is_processing_packet_ = true; bool result = ProcessPacketInternal(packet); is_processing_packet_ = false; @@ -1968,8 +1929,10 @@ bool QuicFramer::ProcessIetfDataPacket(QuicDataReader* encrypted_reader, // Handle the payload. if (VersionHasIetfQuicFrames(version_.transport_version)) { current_received_frame_type_ = 0; + previously_received_frame_type_ = 0; if (!ProcessIetfFrameData(&reader, *header, decrypted_level)) { current_received_frame_type_ = 0; + previously_received_frame_type_ = 0; QUICHE_DCHECK_NE(QUIC_NO_ERROR, error_); // ProcessIetfFrameData sets the error. QUICHE_DCHECK_NE("", detailed_error_); @@ -1978,6 +1941,7 @@ bool QuicFramer::ProcessIetfDataPacket(QuicDataReader* encrypted_reader, return false; } current_received_frame_type_ = 0; + previously_received_frame_type_ = 0; } else { if (!ProcessFrameData(&reader, *header)) { QUICHE_DCHECK_NE(QUIC_NO_ERROR, @@ -2774,7 +2738,6 @@ bool QuicFramer::ProcessAndValidateIetfConnectionIdLength( bool QuicFramer::ValidateReceivedConnectionIds(const QuicPacketHeader& header) { bool skip_server_connection_id_validation = - do_not_synthesize_source_cid_for_short_header_ && perspective_ == Perspective::IS_CLIENT && header.form == IETF_QUIC_SHORT_HEADER_PACKET; if (!skip_server_connection_id_validation && @@ -2786,13 +2749,8 @@ bool QuicFramer::ValidateReceivedConnectionIds(const QuicPacketHeader& header) { } bool skip_client_connection_id_validation = - do_not_synthesize_source_cid_for_short_header_ && perspective_ == Perspective::IS_SERVER && header.form == IETF_QUIC_SHORT_HEADER_PACKET; - if (skip_client_connection_id_validation) { - QUIC_RELOADABLE_FLAG_COUNT_N( - quic_do_not_synthesize_source_cid_for_short_header, 2, 3); - } if (!skip_client_connection_id_validation && version_.SupportsClientConnectionIds() && !QuicUtils::IsConnectionIdValidForVersion( @@ -2829,15 +2787,6 @@ bool QuicFramer::ProcessIetfPacketHeader(QuicDataReader* reader, header->destination_connection_id_included = CONNECTION_ID_PRESENT; header->source_connection_id_included = header->version_flag ? CONNECTION_ID_PRESENT : CONNECTION_ID_ABSENT; - if (!do_not_synthesize_source_cid_for_short_header_ && - header->source_connection_id_included == CONNECTION_ID_ABSENT) { - QUICHE_DCHECK(header->source_connection_id.IsEmpty()); - if (perspective_ == Perspective::IS_CLIENT) { - header->source_connection_id = last_serialized_server_connection_id_; - } else { - header->source_connection_id = last_serialized_client_connection_id_; - } - } if (!ValidateReceivedConnectionIds(*header)) { return false; @@ -2925,13 +2874,6 @@ bool QuicFramer::ProcessIetfPacketHeader(QuicDataReader* reader, set_detailed_error("Client connection ID not supported in this version."); return false; } - if (!do_not_synthesize_source_cid_for_short_header_) { - if (perspective_ == Perspective::IS_CLIENT) { - header->source_connection_id = last_serialized_server_connection_id_; - } else { - header->source_connection_id = last_serialized_client_connection_id_; - } - } } return ValidateReceivedConnectionIds(*header); @@ -3195,11 +3137,14 @@ bool QuicFramer::IsIetfFrameTypeExpectedForEncryptionLevel( case ENCRYPTION_INITIAL: case ENCRYPTION_HANDSHAKE: return frame_type == IETF_CRYPTO || frame_type == IETF_ACK || + frame_type == IETF_ACK_ECN || + frame_type == IETF_ACK_RECEIVE_TIMESTAMPS || frame_type == IETF_PING || frame_type == IETF_PADDING || frame_type == IETF_CONNECTION_CLOSE; case ENCRYPTION_ZERO_RTT: - return !(frame_type == IETF_ACK || frame_type == IETF_CRYPTO || - frame_type == IETF_HANDSHAKE_DONE || + return !(frame_type == IETF_ACK || frame_type == IETF_ACK_ECN || + frame_type == IETF_ACK_RECEIVE_TIMESTAMPS || + frame_type == IETF_CRYPTO || frame_type == IETF_HANDSHAKE_DONE || frame_type == IETF_NEW_TOKEN || frame_type == IETF_PATH_RESPONSE || frame_type == IETF_RETIRE_CONNECTION_ID); @@ -3232,21 +3177,16 @@ bool QuicFramer::ProcessIetfFrameData(QuicDataReader* reader, set_detailed_error("Unable to read frame type."); return RaiseError(QUIC_INVALID_FRAME_DATA); } - if (reject_unexpected_ietf_frame_types_) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_reject_unexpected_ietf_frame_types, 1, - 2); - if (!IsIetfFrameTypeExpectedForEncryptionLevel(frame_type, - decrypted_level)) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_reject_unexpected_ietf_frame_types, 2, - 2); - set_detailed_error(absl::StrCat( - "IETF frame type ", - QuicIetfFrameTypeString(static_cast(frame_type)), - " is unexpected at encryption level ", - EncryptionLevelToString(decrypted_level))); - return RaiseError(IETF_QUIC_PROTOCOL_VIOLATION); - } + if (!IsIetfFrameTypeExpectedForEncryptionLevel(frame_type, + decrypted_level)) { + set_detailed_error(absl::StrCat( + "IETF frame type ", + QuicIetfFrameTypeString(static_cast(frame_type)), + " is unexpected at encryption level ", + EncryptionLevelToString(decrypted_level))); + return RaiseError(IETF_QUIC_PROTOCOL_VIOLATION); } + previously_received_frame_type_ = current_received_frame_type_; current_received_frame_type_ = frame_type; // Is now the number of bytes into which the frame type was encoded. @@ -3476,6 +3416,14 @@ bool QuicFramer::ProcessIetfFrameData(QuicDataReader* reader, } break; } + case IETF_ACK_RECEIVE_TIMESTAMPS: + if (!process_timestamps_) { + set_detailed_error("Unsupported frame type."); + QUIC_DLOG(WARNING) + << ENDPOINT << "IETF_ACK_RECEIVE_TIMESTAMPS not supported"; + return RaiseError(QUIC_INVALID_FRAME_DATA); + } + ABSL_FALLTHROUGH_INTENDED; case IETF_ACK_ECN: case IETF_ACK: { QuicAckFrame frame; @@ -4143,7 +4091,12 @@ bool QuicFramer::ProcessIetfAckFrame(QuicDataReader* reader, ack_block_count--; } - if (frame_type == IETF_ACK_ECN) { + if (frame_type == IETF_ACK_RECEIVE_TIMESTAMPS) { + QUICHE_DCHECK(process_timestamps_); + if (!ProcessIetfTimestampsInAckFrame(ack_frame->largest_acked, reader)) { + return false; + } + } else if (frame_type == IETF_ACK_ECN) { ack_frame->ecn_counters_populated = true; if (!reader->ReadVarInt62(&ack_frame->ect_0_count)) { set_detailed_error("Unable to read ack ect_0_count."); @@ -4173,6 +4126,76 @@ bool QuicFramer::ProcessIetfAckFrame(QuicDataReader* reader, return true; } +bool QuicFramer::ProcessIetfTimestampsInAckFrame(QuicPacketNumber largest_acked, + QuicDataReader* reader) { + uint64_t timestamp_range_count; + if (!reader->ReadVarInt62(×tamp_range_count)) { + set_detailed_error("Unable to read receive timestamp range count."); + return false; + } + if (timestamp_range_count == 0) { + return true; + } + + QuicPacketNumber packet_number = largest_acked; + + // Iterate through all timestamp ranges, each of which represents a block of + // contiguous packets for which receive timestamps are being reported. Each + // range is of the form: + // + // Timestamp Range { + // Gap (i), + // Timestamp Delta Count (i), + // Timestamp Delta (i) ..., + // } + for (uint64_t i = 0; i < timestamp_range_count; i++) { + uint64_t gap; + if (!reader->ReadVarInt62(&gap)) { + set_detailed_error("Unable to read receive timestamp gap."); + return false; + } + if (packet_number.ToUint64() < gap) { + set_detailed_error("Receive timestamp gap too high."); + return false; + } + packet_number = packet_number - gap; + uint64_t timestamp_count; + if (!reader->ReadVarInt62(×tamp_count)) { + set_detailed_error("Unable to read receive timestamp count."); + return false; + } + if (packet_number.ToUint64() < timestamp_count) { + set_detailed_error("Receive timestamp count too high."); + return false; + } + for (uint64_t j = 0; j < timestamp_count; j++) { + uint64_t timestamp_delta; + if (!reader->ReadVarInt62(×tamp_delta)) { + set_detailed_error("Unable to read receive timestamp delta."); + return false; + } + // The first timestamp delta is relative to framer creation time; whereas + // subsequent deltas are relative to the previous delta in decreasing + // packet order. + timestamp_delta = timestamp_delta << receive_timestamps_exponent_; + if (i == 0 && j == 0) { + last_timestamp_ = CalculateTimestampFromWire(timestamp_delta); + } else { + last_timestamp_ = last_timestamp_ - + QuicTime::Delta::FromMicroseconds(timestamp_delta); + if (last_timestamp_ < QuicTime::Delta::Zero()) { + set_detailed_error("Receive timestamp delta too high."); + return false; + } + } + visitor_->OnAckTimestamp(packet_number - j, + creation_time_ + last_timestamp_); + } + packet_number = packet_number - (timestamp_count - 1); + } + return true; +} + bool QuicFramer::ProcessStopWaitingFrame(QuicDataReader* reader, const QuicPacketHeader& header, QuicStopWaitingFrame* stop_waiting) { @@ -4434,6 +4457,26 @@ bool QuicFramer::DoKeyUpdate(KeyUpdateReason reason) { previous_decrypter_ = std::move(decrypter_[ENCRYPTION_FORWARD_SECURE]); decrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_decrypter_); encrypter_[ENCRYPTION_FORWARD_SECURE] = std::move(next_encrypter); + switch (reason) { + case KeyUpdateReason::kInvalid: + QUIC_CODE_COUNT(quic_key_update_invalid); + break; + case KeyUpdateReason::kRemote: + QUIC_CODE_COUNT(quic_key_update_remote); + break; + case KeyUpdateReason::kLocalForTests: + QUIC_CODE_COUNT(quic_key_update_local_for_tests); + break; + case KeyUpdateReason::kLocalForInteropRunner: + QUIC_CODE_COUNT(quic_key_update_local_for_interop_runner); + break; + case KeyUpdateReason::kLocalAeadConfidentialityLimit: + QUIC_CODE_COUNT(quic_key_update_local_aead_confidentiality_limit); + break; + case KeyUpdateReason::kLocalKeyUpdateLimitOverride: + QUIC_CODE_COUNT(quic_key_update_local_limit_override); + break; + } visitor_->OnKeyUpdate(reason); return true; } @@ -4872,10 +4915,8 @@ bool QuicFramer::DecryptPayload(size_t udp_packet_length, if ((current_key_phase_first_received_packet_number_.IsInitialized() && header.packet_number > current_key_phase_first_received_packet_number_) || - (GetQuicReloadableFlag(quic_fix_key_update_on_first_packet) && - !current_key_phase_first_received_packet_number_.IsInitialized() && + (!current_key_phase_first_received_packet_number_.IsInitialized() && !key_update_performed_)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_fix_key_update_on_first_packet); if (!next_decrypter_) { next_decrypter_ = visitor_->AdvanceKeysAndCreateCurrentOneRttDecrypter(); @@ -6570,16 +6611,12 @@ void QuicFramer::EnableMultiplePacketNumberSpacesSupport() { QuicErrorCode QuicFramer::ParsePublicHeaderDispatcher( const QuicEncryptedPacket& packet, uint8_t expected_destination_connection_id_length, - PacketHeaderFormat* format, - QuicLongHeaderType* long_packet_type, - bool* version_present, - bool* has_length_prefix, - QuicVersionLabel* version_label, - ParsedQuicVersion* parsed_version, + PacketHeaderFormat* format, QuicLongHeaderType* long_packet_type, + bool* version_present, bool* has_length_prefix, + QuicVersionLabel* version_label, ParsedQuicVersion* parsed_version, QuicConnectionId* destination_connection_id, QuicConnectionId* source_connection_id, - bool* retry_token_present, - absl::string_view* retry_token, + absl::optional* retry_token, std::string* detailed_error) { QuicDataReader reader(packet.data(), packet.length()); if (reader.IsDoneReading()) { @@ -6587,17 +6624,39 @@ QuicErrorCode QuicFramer::ParsePublicHeaderDispatcher( return QUIC_INVALID_PACKET_HEADER; } const uint8_t first_byte = reader.PeekByte(); + if ((first_byte & FLAGS_LONG_HEADER) == 0 && + (first_byte & FLAGS_FIXED_BIT) == 0 && + (first_byte & FLAGS_DEMULTIPLEXING_BIT) == 0) { + // All versions of Google QUIC up to and including Q043 set + // FLAGS_DEMULTIPLEXING_BIT to one on all client-to-server packets. Q044 + // and Q045 were never default-enabled in production. All subsequent + // versions of Google QUIC (starting with Q046) require FLAGS_FIXED_BIT to + // be set to one on all packets. All versions of IETF QUIC (since + // draft-ietf-quic-transport-17 which was earlier than the first IETF QUIC + // version that was deployed in production by any implementation) also + // require FLAGS_FIXED_BIT to be set to one on all packets. If a packet + // has the FLAGS_LONG_HEADER bit set to one, it could be a first flight + // from an unknown future version that allows the other two bits to be set + // to zero. Based on this, packets that have all three of those bits set + // to zero are known to be invalid. + *detailed_error = "Invalid flags."; + return QUIC_INVALID_PACKET_HEADER; + } const bool ietf_format = QuicUtils::IsIetfPacketHeader(first_byte); uint8_t unused_first_byte; QuicVariableLengthIntegerLength retry_token_length_length; + absl::string_view maybe_retry_token; QuicErrorCode error_code = ParsePublicHeader( &reader, expected_destination_connection_id_length, ietf_format, &unused_first_byte, format, version_present, has_length_prefix, version_label, parsed_version, destination_connection_id, source_connection_id, long_packet_type, &retry_token_length_length, - retry_token, detailed_error); - *retry_token_present = - retry_token_length_length != VARIABLE_LENGTH_INTEGER_LENGTH_0; + &maybe_retry_token, detailed_error); + if (retry_token_length_length != VARIABLE_LENGTH_INTEGER_LENGTH_0) { + *retry_token = maybe_retry_token; + } else { + retry_token->reset(); + } return error_code; } diff --git a/gquiche/quic/core/quic_framer.h b/gquiche/quic/core/quic_framer.h index 2cf509d4..26cc0ef6 100644 --- a/gquiche/quic/core/quic_framer.h +++ b/gquiche/quic/core/quic_framer.h @@ -309,10 +309,17 @@ class QUIC_EXPORT_PRIVATE QuicFramer { QuicErrorCode error() const { return error_; } // Allows enabling or disabling of timestamp processing and serialization. - void set_process_timestamps(bool process_timestamps) { + // TODO(ianswett): Remove the const once timestamps are negotiated via + // transport params. + void set_process_timestamps(bool process_timestamps) const { process_timestamps_ = process_timestamps; } + // Sets the exponent to use when writing/reading ACK receive timestamps. + void set_receive_timestamps_exponent(uint32_t exponent) { + receive_timestamps_exponent_ = exponent; + } + // Pass a UDP packet into the framer for parsing. // Return true if the packet was processed successfully. |packet| must be a // single, complete UDP packet (not a frame of a packet). This packet @@ -446,16 +453,12 @@ class QUIC_EXPORT_PRIVATE QuicFramer { static QuicErrorCode ParsePublicHeaderDispatcher( const QuicEncryptedPacket& packet, uint8_t expected_destination_connection_id_length, - PacketHeaderFormat* format, - QuicLongHeaderType* long_packet_type, - bool* version_present, - bool* has_length_prefix, - QuicVersionLabel* version_label, - ParsedQuicVersion* parsed_version, + PacketHeaderFormat* format, QuicLongHeaderType* long_packet_type, + bool* version_present, bool* has_length_prefix, + QuicVersionLabel* version_label, ParsedQuicVersion* parsed_version, QuicConnectionId* destination_connection_id, QuicConnectionId* source_connection_id, - bool* retry_token_present, - absl::string_view* retry_token, + absl::optional* retry_token, std::string* detailed_error); // Serializes a packet containing |frames| into |buffer|. @@ -634,6 +637,8 @@ class QUIC_EXPORT_PRIVATE QuicFramer { Perspective perspective() const { return perspective_; } + QuicStreamFrameDataProducer* data_producer() const { return data_producer_; } + void set_data_producer(QuicStreamFrameDataProducer* data_producer) { data_producer_ = data_producer; } @@ -648,6 +653,10 @@ class QUIC_EXPORT_PRIVATE QuicFramer { return current_received_frame_type_; } + uint64_t previously_received_frame_type() const { + return previously_received_frame_type_; + } + // The connection ID length the framer expects on incoming IETF short headers // on the server. uint8_t GetExpectedServerConnectionIdLength() { @@ -714,10 +723,6 @@ class QUIC_EXPORT_PRIVATE QuicFramer { drop_incoming_retry_packets_ = drop_incoming_retry_packets; } - bool do_not_synthesize_source_cid_for_short_header() const { - return do_not_synthesize_source_cid_for_short_header_; - } - private: friend class test::QuicFramerPeer; @@ -852,6 +857,8 @@ class QUIC_EXPORT_PRIVATE QuicFramer { bool ProcessIetfAckFrame(QuicDataReader* reader, uint64_t frame_type, QuicAckFrame* ack_frame); + bool ProcessIetfTimestampsInAckFrame(QuicPacketNumber largest_acked, + QuicDataReader* reader); bool ProcessStopWaitingFrame(QuicDataReader* reader, const QuicPacketHeader& header, QuicStopWaitingFrame* stop_waiting); @@ -1115,7 +1122,10 @@ class QUIC_EXPORT_PRIVATE QuicFramer { // The diversification nonce from the last received packet. DiversificationNonce last_nonce_; // If true, send and process timestamps in the ACK frame. - bool process_timestamps_; + // TODO(ianswett): Remove the mutable once set_process_timestamps isn't const. + mutable bool process_timestamps_; + // The exponent to use when writing/reading ACK receive timestamps. + uint32_t receive_timestamps_exponent_; // The creation time of the connection, used to calculate timestamps. QuicTime creation_time_; // The last timestamp received if process_timestamps_ is true. @@ -1174,14 +1184,6 @@ class QUIC_EXPORT_PRIVATE QuicFramer { // Indicates whether received RETRY packets should be dropped. bool drop_incoming_retry_packets_ = false; - bool reject_unexpected_ietf_frame_types_ = - GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types); - - // Indicates whether source connection ID should be synthesized when read - // short header packet. - const bool do_not_synthesize_source_cid_for_short_header_ = - GetQuicReloadableFlag(quic_do_not_synthesize_source_cid_for_short_header); - // The length in bytes of the last packet number written to an IETF-framed // packet. size_t last_written_packet_number_length_; @@ -1200,6 +1202,11 @@ class QUIC_EXPORT_PRIVATE QuicFramer { // the Transport Connection Close when there is an error during frame // processing. uint64_t current_received_frame_type_; + + // TODO(haoyuewang) Remove this debug utility. + // The type of the IETF frame preceding the frame currently being processed. 0 + // when not processing a frame or only 1 frame has been processed. + uint64_t previously_received_frame_type_; }; // Look for and parse the error code from the ":" text that diff --git a/gquiche/quic/core/quic_framer_test.cc b/gquiche/quic/core/quic_framer_test.cc index 74cd8268..f168b687 100644 --- a/gquiche/quic/core/quic_framer_test.cc +++ b/gquiche/quic/core/quic_framer_test.cc @@ -34,10 +34,10 @@ #include "gquiche/quic/test_tools/quic_framer_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" #include "gquiche/quic/test_tools/simple_data_producer.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" using testing::_; +using testing::ContainerEq; using testing::Return; namespace quic { @@ -62,9 +62,10 @@ QuicConnectionId FramerTestConnectionIdPlusOne() { } QuicConnectionId FramerTestConnectionIdNineBytes() { - char connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, - 0x54, 0x32, 0x10, 0x42}; - return QuicConnectionId(connection_id_bytes, sizeof(connection_id_bytes)); + uint8_t connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, + 0x54, 0x32, 0x10, 0x42}; + return QuicConnectionId(reinterpret_cast(connection_id_bytes), + sizeof(connection_id_bytes)); } const QuicPacketNumber kPacketNumber = QuicPacketNumber(UINT64_C(0x12345678)); @@ -358,7 +359,9 @@ class TestQuicVisitor : public QuicFramerVisitorInterface { ack_frames_.push_back(std::make_unique(ack_frame)); if (VersionHasIetfQuicFrames(transport_version_)) { EXPECT_TRUE(IETF_ACK == framer_->current_received_frame_type() || - IETF_ACK_ECN == framer_->current_received_frame_type()); + IETF_ACK_ECN == framer_->current_received_frame_type() || + IETF_ACK_RECEIVE_TIMESTAMPS == + framer_->current_received_frame_type()); } else { EXPECT_EQ(0u, framer_->current_received_frame_type()); } @@ -370,7 +373,9 @@ class TestQuicVisitor : public QuicFramerVisitorInterface { ack_frames_[ack_frames_.size() - 1]->packets.AddRange(start, end); if (VersionHasIetfQuicFrames(transport_version_)) { EXPECT_TRUE(IETF_ACK == framer_->current_received_frame_type() || - IETF_ACK_ECN == framer_->current_received_frame_type()); + IETF_ACK_ECN == framer_->current_received_frame_type() || + IETF_ACK_RECEIVE_TIMESTAMPS == + framer_->current_received_frame_type()); } else { EXPECT_EQ(0u, framer_->current_received_frame_type()); } @@ -381,7 +386,14 @@ class TestQuicVisitor : public QuicFramerVisitorInterface { QuicTime timestamp) override { ack_frames_[ack_frames_.size() - 1]->received_packet_times.push_back( std::make_pair(packet_number, timestamp)); - EXPECT_EQ(0u, framer_->current_received_frame_type()); + if (VersionHasIetfQuicFrames(transport_version_)) { + EXPECT_TRUE(IETF_ACK == framer_->current_received_frame_type() || + IETF_ACK_ECN == framer_->current_received_frame_type() || + IETF_ACK_RECEIVE_TIMESTAMPS == + framer_->current_received_frame_type()); + } else { + EXPECT_EQ(0u, framer_->current_received_frame_type()); + } return true; } @@ -896,6 +908,11 @@ class QuicFramerTest : public QuicTestWithParam { ((n - 1) * QuicUtils::StreamIdDelta(transport_version)); } + QuicTime CreationTimePlus(uint64_t offset_us) { + return framer_.creation_time() + + QuicTime::Delta::FromMicroseconds(offset_us); + } + test::TestEncrypter* encrypter_; test::TestDecrypter* decrypter_; ParsedQuicVersion version_; @@ -1122,15 +1139,15 @@ TEST_P(QuicFramerTest, PacketHeader) { QuicConnectionId destination_connection_id, source_connection_id; QuicVersionLabel version_label; std::string detailed_error; - bool retry_token_present, use_length_prefix; - absl::string_view retry_token; + bool use_length_prefix; + absl::optional retry_token; ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( *encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_flag, &use_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); - EXPECT_FALSE(retry_token_present); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_FALSE(retry_token.has_value()); EXPECT_FALSE(use_length_prefix); EXPECT_THAT(error_code, IsQuicNoError()); EXPECT_EQ(GOOGLE_QUIC_PACKET, format); @@ -1187,17 +1204,17 @@ TEST_P(QuicFramerTest, LongPacketHeader) { QuicConnectionId destination_connection_id, source_connection_id; QuicVersionLabel version_label; std::string detailed_error; - bool retry_token_present, use_length_prefix; - absl::string_view retry_token; + bool use_length_prefix; + absl::optional retry_token; ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( *encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_flag, &use_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); EXPECT_THAT(error_code, IsQuicNoError()); EXPECT_EQ("", detailed_error); - EXPECT_FALSE(retry_token_present); + EXPECT_FALSE(retry_token.has_value()); EXPECT_FALSE(use_length_prefix); EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); EXPECT_TRUE(version_flag); @@ -1267,16 +1284,16 @@ TEST_P(QuicFramerTest, LongPacketHeaderWithBothConnectionIds) { QuicConnectionId destination_connection_id, source_connection_id; QuicVersionLabel version_label = 0; std::string detailed_error = ""; - bool retry_token_present, use_length_prefix; - absl::string_view retry_token; + bool use_length_prefix; + absl::optional retry_token; ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_flag, &use_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); EXPECT_THAT(error_code, IsQuicNoError()); - EXPECT_FALSE(retry_token_present); + EXPECT_FALSE(retry_token.has_value()); EXPECT_EQ(framer_.version().HasLengthPrefixedConnectionIds(), use_length_prefix); EXPECT_EQ("", detailed_error); @@ -1286,6 +1303,27 @@ TEST_P(QuicFramerTest, LongPacketHeaderWithBothConnectionIds) { EXPECT_EQ(FramerTestConnectionIdPlusOne(), source_connection_id); } +TEST_P(QuicFramerTest, AllZeroPacketParsingFails) { + unsigned char packet[1200] = {}; + QuicEncryptedPacket encrypted(AsChars(packet), ABSL_ARRAYSIZE(packet), false); + PacketHeaderFormat format = GOOGLE_QUIC_PACKET; + QuicLongHeaderType long_packet_type = INVALID_PACKET_TYPE; + bool version_flag = false; + QuicConnectionId destination_connection_id, source_connection_id; + QuicVersionLabel version_label = 0; + std::string detailed_error = ""; + bool use_length_prefix; + absl::optional retry_token; + ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); + const QuicErrorCode error_code = QuicFramer::ParsePublicHeaderDispatcher( + encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, + &version_flag, &use_length_prefix, &version_label, &parsed_version, + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); + EXPECT_EQ(error_code, QUIC_INVALID_PACKET_HEADER); + EXPECT_EQ(detailed_error, "Invalid flags."); +} + TEST_P(QuicFramerTest, ParsePublicHeader) { // clang-format off unsigned char packet[] = { @@ -1479,9 +1517,6 @@ TEST_P(QuicFramerTest, ClientConnectionIdFromShortHeaderToClient) { ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(FramerTestConnectionId(), visitor_.header_->destination_connection_id); - if (!framer_.do_not_synthesize_source_cid_for_short_header()) { - EXPECT_EQ(TestConnectionId(0x33), visitor_.header_->source_connection_id); - } } // In short header packets from client to server, the client connection ID @@ -1515,9 +1550,6 @@ TEST_P(QuicFramerTest, ClientConnectionIdFromShortHeaderToServer) { ASSERT_TRUE(visitor_.header_.get()); EXPECT_EQ(FramerTestConnectionId(), visitor_.header_->destination_connection_id); - if (!framer_.do_not_synthesize_source_cid_for_short_header()) { - EXPECT_EQ(TestConnectionId(0x33), visitor_.header_->source_connection_id); - } } TEST_P(QuicFramerTest, PacketHeaderWith0ByteConnectionId) { @@ -1567,9 +1599,6 @@ TEST_P(QuicFramerTest, PacketHeaderWith0ByteConnectionId) { EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsError(QUIC_MISSING_PAYLOAD)); ASSERT_TRUE(visitor_.header_.get()); - if (!framer_.do_not_synthesize_source_cid_for_short_header()) { - EXPECT_EQ(FramerTestConnectionId(), visitor_.header_->source_connection_id); - } EXPECT_FALSE(visitor_.header_->reset_flag); EXPECT_FALSE(visitor_.header_->version_flag); EXPECT_EQ(kPacketNumber, visitor_.header_->packet_number); @@ -2148,7 +2177,7 @@ TEST_P(QuicFramerTest, PaddingFrame) { 0x00, 0x00, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -2180,8 +2209,8 @@ TEST_P(QuicFramerTest, PaddingFrame) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -2269,7 +2298,7 @@ TEST_P(QuicFramerTest, StreamFrame) { 'r', 'l', 'd', '!'}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -2302,7 +2331,7 @@ TEST_P(QuicFramerTest, StreamFrame) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -2592,7 +2621,7 @@ TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { 'r', 'l', 'd', '!'}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -2625,7 +2654,7 @@ TEST_P(QuicFramerTest, StreamFrame2ByteStreamId) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -2711,7 +2740,7 @@ TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { 'r', 'l', 'd', '!'}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -2744,7 +2773,7 @@ TEST_P(QuicFramerTest, StreamFrame1ByteStreamId) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -2889,7 +2918,7 @@ TEST_P(QuicFramerTest, StreamFrameWithVersion) { 'r', 'l', 'd', '!'}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) {"", @@ -2943,7 +2972,7 @@ TEST_P(QuicFramerTest, StreamFrameWithVersion) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasLongHeaderLengths() ? packet49 : (framer_.version().HasIetfInvariantHeader() ? packet46 @@ -3138,7 +3167,7 @@ TEST_P(QuicFramerTest, AckFrameOneAckBlock) { {0x00}} }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short packet, 4 byte packet number) {"", {0x43}}, @@ -3176,7 +3205,7 @@ TEST_P(QuicFramerTest, AckFrameOneAckBlock) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -3259,7 +3288,7 @@ TEST_P(QuicFramerTest, FirstAckFrameUnderflow) { {0x00}} }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3289,7 +3318,7 @@ TEST_P(QuicFramerTest, FirstAckFrameUnderflow) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -3309,7 +3338,7 @@ TEST_P(QuicFramerTest, ThirdAckBlockUnderflowGap) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3349,12 +3378,12 @@ TEST_P(QuicFramerTest, ThirdAckBlockUnderflowGap) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_EQ( framer_.detailed_error(), "Underflow with gap block length 30 previous ack block start is 30."); - CheckFramingBoundaries(packet99, QUIC_INVALID_ACK_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); } // This test checks that the ack frame processor correctly identifies @@ -3369,7 +3398,7 @@ TEST_P(QuicFramerTest, ThirdAckBlockUnderflowAck) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3407,11 +3436,11 @@ TEST_P(QuicFramerTest, ThirdAckBlockUnderflowAck) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_EQ(framer_.detailed_error(), "Underflow with ack block length 31 latest ack block end is 25."); - CheckFramingBoundaries(packet99, QUIC_INVALID_ACK_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); } // Tests a variety of ack block wrap scenarios. For example, if the @@ -3427,7 +3456,7 @@ TEST_P(QuicFramerTest, AckBlockUnderflowGapWrap) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3461,11 +3490,11 @@ TEST_P(QuicFramerTest, AckBlockUnderflowGapWrap) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_EQ(framer_.detailed_error(), "Underflow with gap block length 2 previous ack block start is 1."); - CheckFramingBoundaries(packet99, QUIC_INVALID_ACK_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); } // As AckBlockUnderflowGapWrap, but in this test, it's the ack @@ -3479,7 +3508,7 @@ TEST_P(QuicFramerTest, AckBlockUnderflowAckWrap) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3513,11 +3542,11 @@ TEST_P(QuicFramerTest, AckBlockUnderflowAckWrap) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_EQ(framer_.detailed_error(), "Underflow with ack block length 10 latest ack block end is 1."); - CheckFramingBoundaries(packet99, QUIC_INVALID_ACK_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_ACK_DATA); } // An ack block that acks the entire range, 1...0x3fffffffffffffff @@ -3530,7 +3559,7 @@ TEST_P(QuicFramerTest, AckBlockAcksEverything) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3561,7 +3590,7 @@ TEST_P(QuicFramerTest, AckBlockAcksEverything) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_EQ(1u, visitor_.ack_frames_.size()); const QuicAckFrame& frame = *visitor_.ack_frames_[0]; @@ -3746,7 +3775,7 @@ TEST_P(QuicFramerTest, AckFrameOneAckBlockMaxLength) { {0x00}} }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -3776,7 +3805,7 @@ TEST_P(QuicFramerTest, AckFrameOneAckBlockMaxLength) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -3938,7 +3967,7 @@ TEST_P(QuicFramerTest, AckFrameTwoTimeStampsMultipleAckBlocks) { { 0x32, 0x10 }}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", { 0x43 }}, @@ -3949,9 +3978,9 @@ TEST_P(QuicFramerTest, AckFrameTwoTimeStampsMultipleAckBlocks) { {"", { 0x12, 0x34, 0x56, 0x78 }}, - // frame type (IETF_ACK frame) + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) {"", - { 0x02 }}, + { 0x22 }}, // largest acked {"Unable to read largest acked.", { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 @@ -3992,12 +4021,24 @@ TEST_P(QuicFramerTest, AckFrameTwoTimeStampsMultipleAckBlocks) { // ack block length. { "Unable to read ack block value.", { kVarInt62OneByte + 0x03 }}, // block is 3 packets. + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62FourBytes + 0x36, 0x54, 0x32, 0x10 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x32, 0x10 }}, }; // clang-format on PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( @@ -4018,12 +4059,331 @@ TEST_P(QuicFramerTest, AckFrameTwoTimeStampsMultipleAckBlocks) { EXPECT_EQ(kSmallLargestObserved, LargestAcked(frame)); ASSERT_EQ(4254u, frame.packets.NumPacketsSlow()); EXPECT_EQ(4u, frame.packets.NumIntervals()); - if (VersionHasIetfQuicFrames(framer_.transport_version())) { - EXPECT_EQ(0u, frame.received_packet_times.size()); - } else { - EXPECT_EQ(2u, frame.received_packet_times.size()); + EXPECT_EQ(2u, frame.received_packet_times.size()); +} + +TEST_P(QuicFramerTest, AckFrameMultipleReceiveTimestampRanges) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; } - CheckFramingBoundaries(fragments, QUIC_INVALID_ACK_DATA); + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x03 }}, + + // Timestamp range 1 (three packets). + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x03 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62FourBytes + 0x29, 0xff, 0xff, 0xff}}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x11, 0x11 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x01}}, + + // Timestamp range 2 (one packet). + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x07 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x10, 0x00 }}, + + // Timestamp range 3 (two packets). + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x0a }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x10 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x01, 0x00 }}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + // Timestamp Range 1. + {LargestAcked(frame) - 2, CreationTimePlus(0x29ffffff)}, + {LargestAcked(frame) - 3, CreationTimePlus(0x29ffeeee)}, + {LargestAcked(frame) - 4, CreationTimePlus(0x29ffeeed)}, + // Timestamp Range 2. + {LargestAcked(frame) - 11, CreationTimePlus(0x29ffdeed)}, + // Timestamp Range 3. + {LargestAcked(frame) - 21, CreationTimePlus(0x29ffdedd)}, + {LargestAcked(frame) - 22, CreationTimePlus(0x29ffdddd)}, + })); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampWithExponent) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x00 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x03 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x29, 0xff}}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x11, 0x11 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x01}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_receive_timestamps_exponent(3); + framer_.set_process_timestamps(true); + EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); + EXPECT_THAT(framer_.error(), IsQuicNoError()); + ASSERT_TRUE(visitor_.header_.get()); + const QuicAckFrame& frame = *visitor_.ack_frames_[0]; + + EXPECT_THAT(frame.received_packet_times, + ContainerEq(PacketTimeVector{ + // Timestamp Range 1. + {LargestAcked(frame), CreationTimePlus(0x29ff << 3)}, + {LargestAcked(frame) - 1, CreationTimePlus(0x18ee << 3)}, + {LargestAcked(frame) - 2, CreationTimePlus(0x18ed << 3)}, + })); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampGapTooHigh) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x79 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x29, 0xff}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "Receive timestamp gap too high.")); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampCountTooHigh) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp count.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x0a}}, + { "Unable to read receive timestamp delta.", + { kVarInt62OneByte + 0x0b}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "Receive timestamp delta too high.")); +} + +TEST_P(QuicFramerTest, AckFrameReceiveTimestampDeltaTooHigh) { + if (!VersionHasIetfQuicFrames(framer_.transport_version())) { + return; + } + SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); + // clang-format off + PacketFragments packet_ietf = { + // type (short header, 4 byte packet number) + {"", + { 0x43 }}, + // connection_id + {"", + { 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10 }}, + // packet number + {"", + { 0x12, 0x34, 0x56, 0x78 }}, + + // frame type (IETF_ACK_RECEIVE_TIMESTAMPS frame) + {"", + { 0x22 }}, + // largest acked + {"Unable to read largest acked.", + { kVarInt62TwoBytes + 0x12, 0x34 }}, // = 4660 + // Zero delta time. + {"Unable to read ack delay time.", + { kVarInt62OneByte + 0x00 }}, + // number of additional ack blocks + {"Unable to read ack block count.", + { kVarInt62OneByte + 0x00 }}, + // first ack block length. + {"Unable to read first ack block length.", + { kVarInt62OneByte + 0x00 }}, // 1st block length = 1 + + // Receive Timestamps. + { "Unable to read receive timestamp range count.", + { kVarInt62OneByte + 0x01 }}, + { "Unable to read receive timestamp gap.", + { kVarInt62OneByte + 0x02 }}, + { "Unable to read receive timestamp count.", + { kVarInt62FourBytes + 0x12, 0x34, 0x56, 0x77 }}, + { "Unable to read receive timestamp delta.", + { kVarInt62TwoBytes + 0x29, 0xff}}, + }; + // clang-format on + + std::unique_ptr encrypted( + AssemblePacketFromFragments(packet_ietf)); + + framer_.set_process_timestamps(true); + EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); + EXPECT_TRUE(absl::StartsWith(framer_.detailed_error(), + "Receive timestamp count too high.")); } TEST_P(QuicFramerTest, AckFrameTimeStampDeltaTooHigh) { @@ -4328,7 +4688,7 @@ TEST_P(QuicFramerTest, RstStreamFrame) { {0x00, 0x00, 0x00, 0x06}} }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -4356,7 +4716,7 @@ TEST_P(QuicFramerTest, RstStreamFrame) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -4433,7 +4793,7 @@ TEST_P(QuicFramerTest, ConnectionCloseFrame) { } }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -4467,7 +4827,7 @@ TEST_P(QuicFramerTest, ConnectionCloseFrame) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -4558,7 +4918,7 @@ TEST_P(QuicFramerTest, ConnectionCloseFrameWithUnknownErrorCode) { } }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -4592,7 +4952,7 @@ TEST_P(QuicFramerTest, ConnectionCloseFrameWithUnknownErrorCode) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -4688,7 +5048,7 @@ TEST_P(QuicFramerTest, ConnectionCloseFrameWithExtractedInfoIgnoreGCuic) { } }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -4722,7 +5082,7 @@ TEST_P(QuicFramerTest, ConnectionCloseFrameWithExtractedInfoIgnoreGCuic) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -4765,7 +5125,7 @@ TEST_P(QuicFramerTest, ApplicationCloseFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -4795,7 +5155,7 @@ TEST_P(QuicFramerTest, ApplicationCloseFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -4814,7 +5174,7 @@ TEST_P(QuicFramerTest, ApplicationCloseFrame) { ASSERT_EQ(0u, visitor_.ack_frames_.size()); - CheckFramingBoundaries(packet99, QUIC_INVALID_CONNECTION_CLOSE_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_CONNECTION_CLOSE_DATA); } // Check that we can extract an error code from an application close. @@ -4826,7 +5186,7 @@ TEST_P(QuicFramerTest, ApplicationCloseFrameExtract) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -4857,7 +5217,7 @@ TEST_P(QuicFramerTest, ApplicationCloseFrameExtract) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -4876,7 +5236,7 @@ TEST_P(QuicFramerTest, ApplicationCloseFrameExtract) { ASSERT_EQ(0u, visitor_.ack_frames_.size()); - CheckFramingBoundaries(packet99, QUIC_INVALID_CONNECTION_CLOSE_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_CONNECTION_CLOSE_DATA); } TEST_P(QuicFramerTest, GoAwayFrame) { @@ -5140,7 +5500,7 @@ TEST_P(QuicFramerTest, MaxDataFrame) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -5161,7 +5521,7 @@ TEST_P(QuicFramerTest, MaxDataFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -5174,7 +5534,7 @@ TEST_P(QuicFramerTest, MaxDataFrame) { visitor_.window_update_frame_.stream_id); EXPECT_EQ(kStreamOffset, visitor_.window_update_frame_.max_data); - CheckFramingBoundaries(packet99, QUIC_INVALID_MAX_DATA_FRAME_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_MAX_DATA_FRAME_DATA); } TEST_P(QuicFramerTest, MaxStreamDataFrame) { @@ -5184,7 +5544,7 @@ TEST_P(QuicFramerTest, MaxStreamDataFrame) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -5208,7 +5568,7 @@ TEST_P(QuicFramerTest, MaxStreamDataFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -5220,7 +5580,7 @@ TEST_P(QuicFramerTest, MaxStreamDataFrame) { EXPECT_EQ(kStreamId, visitor_.window_update_frame_.stream_id); EXPECT_EQ(kStreamOffset, visitor_.window_update_frame_.max_data); - CheckFramingBoundaries(packet99, QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_MAX_STREAM_DATA_FRAME_DATA); } TEST_P(QuicFramerTest, BlockedFrame) { @@ -5262,7 +5622,7 @@ TEST_P(QuicFramerTest, BlockedFrame) { {0x01, 0x02, 0x03, 0x04}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -5286,7 +5646,7 @@ TEST_P(QuicFramerTest, BlockedFrame) { PacketFragments& fragments = VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet); std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); @@ -5339,7 +5699,7 @@ TEST_P(QuicFramerTest, PingFrame) { 0x07, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -5354,11 +5714,11 @@ TEST_P(QuicFramerTest, PingFrame) { QuicEncryptedPacket encrypted( AsChars(VersionHasIetfQuicFrames(framer_.transport_version()) - ? packet99 + ? packet_ietf : (framer_.version().HasIetfInvariantHeader() ? packet46 : packet)), VersionHasIetfQuicFrames(framer_.transport_version()) - ? ABSL_ARRAYSIZE(packet99) + ? ABSL_ARRAYSIZE(packet_ietf) : (framer_.version().HasIetfInvariantHeader() ? ABSL_ARRAYSIZE(packet46) : ABSL_ARRAYSIZE(packet)), @@ -5485,7 +5845,7 @@ TEST_P(QuicFramerTest, MessageFrame) { {{}, {'m', 'e', 's', 's', 'a', 'g', 'e', '2'}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -5515,7 +5875,7 @@ TEST_P(QuicFramerTest, MessageFrame) { std::unique_ptr encrypted; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - encrypted = AssemblePacketFromFragments(packet99); + encrypted = AssemblePacketFromFragments(packet_ietf); } else { encrypted = AssemblePacketFromFragments(packet46); } @@ -5532,7 +5892,7 @@ TEST_P(QuicFramerTest, MessageFrame) { EXPECT_EQ(8u, visitor_.message_frames_[1]->message_length); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - CheckFramingBoundaries(packet99, QUIC_INVALID_MESSAGE_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_MESSAGE_DATA); } else { CheckFramingBoundaries(packet46, QUIC_INVALID_MESSAGE_DATA); } @@ -6162,7 +6522,7 @@ TEST_P(QuicFramerTest, BuildPaddingFramePacket) { 0x00, 0x00, 0x00, 0x00 }; - unsigned char packet99[kMaxOutgoingPacketSize] = { + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -6178,7 +6538,7 @@ TEST_P(QuicFramerTest, BuildPaddingFramePacket) { unsigned char* p = packet; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -6267,7 +6627,7 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithNewPaddingFrame) { 0x00, 0x00, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -6301,8 +6661,8 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithNewPaddingFrame) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -6351,7 +6711,7 @@ TEST_P(QuicFramerTest, Build4ByteSequenceNumberPaddingFramePacket) { 0x00, 0x00, 0x00, 0x00 }; - unsigned char packet99[kMaxOutgoingPacketSize] = { + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -6367,7 +6727,7 @@ TEST_P(QuicFramerTest, Build4ByteSequenceNumberPaddingFramePacket) { unsigned char* p = packet; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -6426,7 +6786,7 @@ TEST_P(QuicFramerTest, Build2ByteSequenceNumberPaddingFramePacket) { 0x00, 0x00, 0x00, 0x00 }; - unsigned char packet99[kMaxOutgoingPacketSize] = { + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { // type (short header, 2 byte packet number) 0x41, // connection_id @@ -6442,7 +6802,7 @@ TEST_P(QuicFramerTest, Build2ByteSequenceNumberPaddingFramePacket) { unsigned char* p = packet; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -6501,7 +6861,7 @@ TEST_P(QuicFramerTest, Build1ByteSequenceNumberPaddingFramePacket) { 0x00, 0x00, 0x00, 0x00 }; - unsigned char packet99[kMaxOutgoingPacketSize] = { + unsigned char packet_ietf[kMaxOutgoingPacketSize] = { // type (short header, 1 byte packet number) 0x40, // connection_id @@ -6517,7 +6877,7 @@ TEST_P(QuicFramerTest, Build1ByteSequenceNumberPaddingFramePacket) { unsigned char* p = packet; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -6597,7 +6957,7 @@ TEST_P(QuicFramerTest, BuildStreamFramePacket) { 'r', 'l', 'd', '!', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -6625,8 +6985,8 @@ TEST_P(QuicFramerTest, BuildStreamFramePacket) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -6721,7 +7081,7 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithVersionFlag) { 'h', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', '!', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (long header with packet type ZERO_RTT_PROTECTED) 0xD3, // version tag @@ -6755,8 +7115,8 @@ TEST_P(QuicFramerTest, BuildStreamFramePacketWithVersionFlag) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasLongHeaderLengths()) { p = packet49; p_size = ABSL_ARRAYSIZE(packet49); @@ -6812,7 +7172,7 @@ TEST_P(QuicFramerTest, BuildCryptoFramePacket) { 'r', 'l', 'd', '!', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -6837,8 +7197,8 @@ TEST_P(QuicFramerTest, BuildCryptoFramePacket) { unsigned char* packet = packet48; size_t packet_size = ABSL_ARRAYSIZE(packet48); if (framer_.version().HasIetfQuicFrames()) { - packet = packet99; - packet_size = ABSL_ARRAYSIZE(packet99); + packet = packet_ietf; + packet_size = ABSL_ARRAYSIZE(packet_ietf); } std::unique_ptr data(BuildDataPacket(header, frames)); @@ -6883,7 +7243,7 @@ TEST_P(QuicFramerTest, CryptoFrame) { 'r', 'l', 'd', '!'}}, }; - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -6912,7 +7272,7 @@ TEST_P(QuicFramerTest, CryptoFrame) { // clang-format on PacketFragments& fragments = - framer_.version().HasIetfQuicFrames() ? packet99 : packet48; + framer_.version().HasIetfQuicFrames() ? packet_ietf : packet48; std::unique_ptr encrypted( AssemblePacketFromFragments(fragments)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); @@ -7087,7 +7447,7 @@ TEST_P(QuicFramerTest, BuildAckFramePacketOneAckBlock) { 0x00, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -7110,8 +7470,8 @@ TEST_P(QuicFramerTest, BuildAckFramePacketOneAckBlock) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -7180,7 +7540,7 @@ TEST_P(QuicFramerTest, BuildAckFramePacketOneAckBlockMaxLength) { }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -7203,8 +7563,8 @@ TEST_P(QuicFramerTest, BuildAckFramePacketOneAckBlockMaxLength) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -7313,7 +7673,7 @@ TEST_P(QuicFramerTest, BuildAckFramePacketMultipleAckBlocks) { 0x00, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -7351,8 +7711,8 @@ TEST_P(QuicFramerTest, BuildAckFramePacketMultipleAckBlocks) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -7572,7 +7932,7 @@ TEST_P(QuicFramerTest, BuildAckFramePacketMaxAckBlocks) { 0x00, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -7660,8 +8020,8 @@ TEST_P(QuicFramerTest, BuildAckFramePacketMaxAckBlocks) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -7771,7 +8131,7 @@ TEST_P(QuicFramerTest, BuildRstFramePacketQuic) { 0x05, 0x06, 0x07, 0x08, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short packet, 4 byte packet number) 0x43, // connection_id @@ -7798,8 +8158,8 @@ TEST_P(QuicFramerTest, BuildRstFramePacketQuic) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -7866,7 +8226,7 @@ TEST_P(QuicFramerTest, BuildCloseFramePacket) { 'n', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -7893,8 +8253,8 @@ TEST_P(QuicFramerTest, BuildCloseFramePacket) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -7970,7 +8330,7 @@ TEST_P(QuicFramerTest, BuildCloseFramePacketExtendedInfo) { 'n', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -7999,8 +8359,8 @@ TEST_P(QuicFramerTest, BuildCloseFramePacketExtendedInfo) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -8125,7 +8485,7 @@ TEST_P(QuicFramerTest, BuildTruncatedCloseFramePacket) { 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8180,8 +8540,8 @@ TEST_P(QuicFramerTest, BuildTruncatedCloseFramePacket) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -8215,7 +8575,7 @@ TEST_P(QuicFramerTest, BuildApplicationCloseFramePacket) { // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8241,8 +8601,8 @@ TEST_P(QuicFramerTest, BuildApplicationCloseFramePacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildTruncatedApplicationCloseFramePacket) { @@ -8268,7 +8628,7 @@ TEST_P(QuicFramerTest, BuildTruncatedApplicationCloseFramePacket) { QuicFrames frames = {QuicFrame(&app_close_frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8322,8 +8682,8 @@ TEST_P(QuicFramerTest, BuildTruncatedApplicationCloseFramePacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildGoAwayPacket) { @@ -8594,7 +8954,7 @@ TEST_P(QuicFramerTest, BuildWindowUpdatePacket) { 0x55, 0x66, 0x77, 0x88, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8618,8 +8978,8 @@ TEST_P(QuicFramerTest, BuildWindowUpdatePacket) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -8648,7 +9008,7 @@ TEST_P(QuicFramerTest, BuildMaxStreamDataPacket) { QuicFrames frames = {QuicFrame(&window_update_frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8670,8 +9030,8 @@ TEST_P(QuicFramerTest, BuildMaxStreamDataPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildMaxDataPacket) { @@ -8694,7 +9054,7 @@ TEST_P(QuicFramerTest, BuildMaxDataPacket) { QuicFrames frames = {QuicFrame(&window_update_frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8714,8 +9074,8 @@ TEST_P(QuicFramerTest, BuildMaxDataPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildBlockedPacket) { @@ -8769,7 +9129,7 @@ TEST_P(QuicFramerTest, BuildBlockedPacket) { 0x01, 0x02, 0x03, 0x04, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short packet, 4 byte packet number) 0x43, // connection_id @@ -8790,8 +9150,8 @@ TEST_P(QuicFramerTest, BuildBlockedPacket) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -8836,7 +9196,7 @@ TEST_P(QuicFramerTest, BuildPingPacket) { 0x07, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -8851,7 +9211,7 @@ TEST_P(QuicFramerTest, BuildPingPacket) { unsigned char* p = packet; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -8956,10 +9316,9 @@ TEST_P(QuicFramerTest, BuildMessagePacket) { header.reset_flag = false; header.version_flag = false; header.packet_number = kPacketNumber; - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); - QuicMessageFrame frame(1, MakeSpan(&allocator_, "message", &storage)); - QuicMessageFrame frame2(2, MakeSpan(&allocator_, "message2", &storage)); + QuicMessageFrame frame(1, MemSliceFromString("message")); + QuicMessageFrame frame2(2, MemSliceFromString("message2")); QuicFrames frames = {QuicFrame(&frame), QuicFrame(&frame2)}; // clang-format off @@ -8983,7 +9342,7 @@ TEST_P(QuicFramerTest, BuildMessagePacket) { 'm', 'e', 's', 's', 'a', 'g', 'e', '2' }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -9006,7 +9365,7 @@ TEST_P(QuicFramerTest, BuildMessagePacket) { unsigned char* p = packet46; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } std::unique_ptr data(BuildDataPacket(header, frames)); @@ -9053,7 +9412,7 @@ TEST_P(QuicFramerTest, BuildMtuDiscoveryPacket) { 0x07, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -9071,7 +9430,7 @@ TEST_P(QuicFramerTest, BuildMtuDiscoveryPacket) { unsigned char* p = packet; if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; + p = packet_ietf; } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -9248,8 +9607,7 @@ TEST_P(QuicFramerTest, BuildPublicResetPacketWithEndpointId) { } TEST_P(QuicFramerTest, BuildIetfStatelessResetPacket) { - if (GetQuicReloadableFlag(quic_fix_stateless_reset)) { - // clang-format off + // clang-format off unsigned char packet[] = { // 1st byte 01XX XXXX 0x40, @@ -9259,77 +9617,45 @@ TEST_P(QuicFramerTest, BuildIetfStatelessResetPacket) { 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f }; - // clang-format on - - // Build the minimal stateless reset packet. - std::unique_ptr data( - framer_.BuildIetfStatelessResetPacket( - FramerTestConnectionId(), - QuicFramer::GetMinStatelessResetPacketLength() + 1, - kTestStatelessResetToken)); - ASSERT_TRUE(data); - EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength(), data->length()); - // Verify the first 2 bits are 01. - EXPECT_FALSE(data->data()[0] & FLAGS_LONG_HEADER); - EXPECT_TRUE(data->data()[0] & FLAGS_FIXED_BIT); - // Verify stateless reset token. - quiche::test::CompareCharArraysWithHexError( - "constructed packet", - data->data() + data->length() - kStatelessResetTokenLength, - kStatelessResetTokenLength, - AsChars(packet) + ABSL_ARRAYSIZE(packet) - kStatelessResetTokenLength, - kStatelessResetTokenLength); - - // Packets with length <= minimal stateless reset does not trigger stateless - // reset. - EXPECT_QUIC_BUG( - std::unique_ptr data2( - framer_.BuildIetfStatelessResetPacket( - FramerTestConnectionId(), - QuicFramer::GetMinStatelessResetPacketLength(), - kTestStatelessResetToken)), - "Tried to build stateless reset packet with received packet length"); - - // Do not send stateless reset >= minimal stateless reset + 1 + max - // connection ID length. - std::unique_ptr data3( - framer_.BuildIetfStatelessResetPacket(FramerTestConnectionId(), 1000, - kTestStatelessResetToken)); - ASSERT_TRUE(data3); - EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength() + 1 + - kQuicMaxConnectionIdWithLengthPrefixLength, - data3->length()); - return; - } - // clang-format off - unsigned char packet[] = { - // type (short header, 1 byte packet number) - 0x70, - // random packet number - 0xFE, - // stateless reset token - 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, - 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, - }; // clang-format on + // Build the minimal stateless reset packet. std::unique_ptr data( - framer_.BuildIetfStatelessResetPacket(FramerTestConnectionId(), 0, - kTestStatelessResetToken)); - ASSERT_TRUE(data != nullptr); - // Skip packet number byte which is random in stateless reset packet. - quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), 1, AsChars(packet), 1); - const size_t random_bytes_length = - data->length() - kPacketHeaderTypeSize - kStatelessResetTokenLength; - EXPECT_EQ(kMinRandomBytesLengthInStatelessReset, random_bytes_length); - // Verify stateless reset token is correct. + framer_.BuildIetfStatelessResetPacket( + FramerTestConnectionId(), + QuicFramer::GetMinStatelessResetPacketLength() + 1, + kTestStatelessResetToken)); + ASSERT_TRUE(data); + EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength(), data->length()); + // Verify the first 2 bits are 01. + EXPECT_FALSE(data->data()[0] & FLAGS_LONG_HEADER); + EXPECT_TRUE(data->data()[0] & FLAGS_FIXED_BIT); + // Verify stateless reset token. quiche::test::CompareCharArraysWithHexError( "constructed packet", data->data() + data->length() - kStatelessResetTokenLength, kStatelessResetTokenLength, AsChars(packet) + ABSL_ARRAYSIZE(packet) - kStatelessResetTokenLength, kStatelessResetTokenLength); + + // Packets with length <= minimal stateless reset does not trigger stateless + // reset. + std::unique_ptr data2( + framer_.BuildIetfStatelessResetPacket( + FramerTestConnectionId(), + QuicFramer::GetMinStatelessResetPacketLength(), + kTestStatelessResetToken)); + ASSERT_FALSE(data2); + + // Do not send stateless reset >= minimal stateless reset + 1 + max + // connection ID length. + std::unique_ptr data3( + framer_.BuildIetfStatelessResetPacket(FramerTestConnectionId(), 1000, + kTestStatelessResetToken)); + ASSERT_TRUE(data3); + EXPECT_EQ(QuicFramer::GetMinStatelessResetPacketLength() + 1 + + kQuicMaxConnectionIdWithLengthPrefixLength, + data3->length()); } TEST_P(QuicFramerTest, EncryptPacket) { @@ -9769,7 +10095,7 @@ TEST_P(QuicFramerTest, StopPacketProcessing) { 0x9A, 0xBE, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -9821,8 +10147,8 @@ TEST_P(QuicFramerTest, StopPacketProcessing) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -9938,7 +10264,7 @@ TEST_P(QuicFramerTest, IetfBlockedFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -9958,7 +10284,7 @@ TEST_P(QuicFramerTest, IetfBlockedFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -9969,7 +10295,7 @@ TEST_P(QuicFramerTest, IetfBlockedFrame) { EXPECT_EQ(kStreamOffset, visitor_.blocked_frame_.offset); - CheckFramingBoundaries(packet99, QUIC_INVALID_BLOCKED_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_BLOCKED_DATA); } TEST_P(QuicFramerTest, BuildIetfBlockedPacket) { @@ -9990,7 +10316,7 @@ TEST_P(QuicFramerTest, BuildIetfBlockedPacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10009,8 +10335,8 @@ TEST_P(QuicFramerTest, BuildIetfBlockedPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, IetfStreamBlockedFrame) { @@ -10021,7 +10347,7 @@ TEST_P(QuicFramerTest, IetfStreamBlockedFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10043,7 +10369,7 @@ TEST_P(QuicFramerTest, IetfStreamBlockedFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10055,7 +10381,7 @@ TEST_P(QuicFramerTest, IetfStreamBlockedFrame) { EXPECT_EQ(kStreamId, visitor_.blocked_frame_.stream_id); EXPECT_EQ(kStreamOffset, visitor_.blocked_frame_.offset); - CheckFramingBoundaries(packet99, QUIC_INVALID_STREAM_BLOCKED_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_STREAM_BLOCKED_DATA); } TEST_P(QuicFramerTest, BuildIetfStreamBlockedPacket) { @@ -10076,7 +10402,7 @@ TEST_P(QuicFramerTest, BuildIetfStreamBlockedPacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10097,8 +10423,8 @@ TEST_P(QuicFramerTest, BuildIetfStreamBlockedPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BiDiMaxStreamsFrame) { @@ -10109,7 +10435,7 @@ TEST_P(QuicFramerTest, BiDiMaxStreamsFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10129,7 +10455,7 @@ TEST_P(QuicFramerTest, BiDiMaxStreamsFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10140,7 +10466,7 @@ TEST_P(QuicFramerTest, BiDiMaxStreamsFrame) { EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); EXPECT_FALSE(visitor_.max_streams_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_MAX_STREAMS_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); } TEST_P(QuicFramerTest, UniDiMaxStreamsFrame) { @@ -10151,7 +10477,7 @@ TEST_P(QuicFramerTest, UniDiMaxStreamsFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10169,7 +10495,7 @@ TEST_P(QuicFramerTest, UniDiMaxStreamsFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); @@ -10181,7 +10507,7 @@ TEST_P(QuicFramerTest, UniDiMaxStreamsFrame) { EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_MAX_STREAMS_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); } TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrame) { @@ -10192,7 +10518,7 @@ TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10212,7 +10538,7 @@ TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10223,7 +10549,7 @@ TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrame) { EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_MAX_STREAMS_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); } TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrame) { @@ -10234,7 +10560,7 @@ TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10252,7 +10578,7 @@ TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); @@ -10264,7 +10590,7 @@ TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrame) { EXPECT_EQ(3u, visitor_.max_streams_frame_.stream_count); EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_MAX_STREAMS_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); } // The following four tests ensure that the framer can deserialize a stream @@ -10281,7 +10607,7 @@ TEST_P(QuicFramerTest, BiDiMaxStreamsFrameTooBig) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10298,8 +10624,8 @@ TEST_P(QuicFramerTest, BiDiMaxStreamsFrameTooBig) { }; // clang-format on - QuicEncryptedPacket encrypted(AsChars(packet99), ABSL_ARRAYSIZE(packet99), - false); + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); ASSERT_TRUE(visitor_.header_.get()); @@ -10319,7 +10645,7 @@ TEST_P(QuicFramerTest, ClientBiDiMaxStreamsFrameTooBig) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // Test runs in client mode, no connection id @@ -10335,8 +10661,8 @@ TEST_P(QuicFramerTest, ClientBiDiMaxStreamsFrameTooBig) { }; // clang-format on - QuicEncryptedPacket encrypted(AsChars(packet99), ABSL_ARRAYSIZE(packet99), - false); + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); @@ -10358,7 +10684,7 @@ TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrameTooBig) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10375,8 +10701,8 @@ TEST_P(QuicFramerTest, ServerUniDiMaxStreamsFrameTooBig) { }; // clang-format on - QuicEncryptedPacket encrypted(AsChars(packet99), ABSL_ARRAYSIZE(packet99), - false); + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10397,7 +10723,7 @@ TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrameTooBig) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // Test runs in client mode, no connection id @@ -10413,8 +10739,8 @@ TEST_P(QuicFramerTest, ClientUniDiMaxStreamsFrameTooBig) { }; // clang-format on - QuicEncryptedPacket encrypted(AsChars(packet99), ABSL_ARRAYSIZE(packet99), - false); + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); @@ -10437,7 +10763,7 @@ TEST_P(QuicFramerTest, MaxStreamsFrameZeroCount) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10451,8 +10777,8 @@ TEST_P(QuicFramerTest, MaxStreamsFrameZeroCount) { }; // clang-format on - QuicEncryptedPacket encrypted(AsChars(packet99), ABSL_ARRAYSIZE(packet99), - false); + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); EXPECT_TRUE(framer_.ProcessPacket(encrypted)); } @@ -10464,7 +10790,7 @@ TEST_P(QuicFramerTest, ServerBiDiStreamsBlockedFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10484,7 +10810,7 @@ TEST_P(QuicFramerTest, ServerBiDiStreamsBlockedFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10496,7 +10822,7 @@ TEST_P(QuicFramerTest, ServerBiDiStreamsBlockedFrame) { EXPECT_EQ(0u, visitor_.max_streams_frame_.stream_count); EXPECT_TRUE(visitor_.max_streams_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_MAX_STREAMS_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_MAX_STREAMS_DATA); } TEST_P(QuicFramerTest, BiDiStreamsBlockedFrame) { @@ -10507,7 +10833,7 @@ TEST_P(QuicFramerTest, BiDiStreamsBlockedFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10528,7 +10854,7 @@ TEST_P(QuicFramerTest, BiDiStreamsBlockedFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10540,7 +10866,7 @@ TEST_P(QuicFramerTest, BiDiStreamsBlockedFrame) { EXPECT_EQ(3u, visitor_.streams_blocked_frame_.stream_count); EXPECT_FALSE(visitor_.streams_blocked_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_STREAMS_BLOCKED_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); } TEST_P(QuicFramerTest, UniDiStreamsBlockedFrame) { @@ -10551,7 +10877,7 @@ TEST_P(QuicFramerTest, UniDiStreamsBlockedFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10572,7 +10898,7 @@ TEST_P(QuicFramerTest, UniDiStreamsBlockedFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10583,7 +10909,7 @@ TEST_P(QuicFramerTest, UniDiStreamsBlockedFrame) { EXPECT_EQ(3u, visitor_.streams_blocked_frame_.stream_count); EXPECT_TRUE(visitor_.streams_blocked_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_STREAMS_BLOCKED_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); } TEST_P(QuicFramerTest, ClientUniDiStreamsBlockedFrame) { @@ -10594,7 +10920,7 @@ TEST_P(QuicFramerTest, ClientUniDiStreamsBlockedFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10613,7 +10939,7 @@ TEST_P(QuicFramerTest, ClientUniDiStreamsBlockedFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); @@ -10625,7 +10951,7 @@ TEST_P(QuicFramerTest, ClientUniDiStreamsBlockedFrame) { EXPECT_EQ(3u, visitor_.streams_blocked_frame_.stream_count); EXPECT_TRUE(visitor_.streams_blocked_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_STREAMS_BLOCKED_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); } // Check that when we get a STREAMS_BLOCKED frame that specifies too large @@ -10640,7 +10966,7 @@ TEST_P(QuicFramerTest, StreamsBlockedFrameTooBig) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // Test runs in client mode, no connection id @@ -10656,8 +10982,8 @@ TEST_P(QuicFramerTest, StreamsBlockedFrameTooBig) { }; // clang-format on - QuicEncryptedPacket encrypted(AsChars(packet99), ABSL_ARRAYSIZE(packet99), - false); + QuicEncryptedPacket encrypted(AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf), false); QuicFramerPeer::SetPerspective(&framer_, Perspective::IS_CLIENT); EXPECT_FALSE(framer_.ProcessPacket(encrypted)); @@ -10676,7 +11002,7 @@ TEST_P(QuicFramerTest, StreamsBlockedFrameZeroCount) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10697,7 +11023,7 @@ TEST_P(QuicFramerTest, StreamsBlockedFrameZeroCount) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10709,7 +11035,7 @@ TEST_P(QuicFramerTest, StreamsBlockedFrameZeroCount) { EXPECT_EQ(0u, visitor_.streams_blocked_frame_.stream_count); EXPECT_TRUE(visitor_.streams_blocked_frame_.unidirectional); - CheckFramingBoundaries(packet99, QUIC_STREAMS_BLOCKED_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_STREAMS_BLOCKED_DATA); } TEST_P(QuicFramerTest, BuildBiDiStreamsBlockedPacket) { @@ -10731,7 +11057,7 @@ TEST_P(QuicFramerTest, BuildBiDiStreamsBlockedPacket) { QuicFrames frames = {QuicFrame(frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10750,8 +11076,8 @@ TEST_P(QuicFramerTest, BuildBiDiStreamsBlockedPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildUniStreamsBlockedPacket) { @@ -10773,7 +11099,7 @@ TEST_P(QuicFramerTest, BuildUniStreamsBlockedPacket) { QuicFrames frames = {QuicFrame(frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10792,8 +11118,8 @@ TEST_P(QuicFramerTest, BuildUniStreamsBlockedPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildBiDiMaxStreamsPacket) { @@ -10815,7 +11141,7 @@ TEST_P(QuicFramerTest, BuildBiDiMaxStreamsPacket) { QuicFrames frames = {QuicFrame(frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10834,8 +11160,8 @@ TEST_P(QuicFramerTest, BuildBiDiMaxStreamsPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, BuildUniDiMaxStreamsPacket) { @@ -10860,7 +11186,7 @@ TEST_P(QuicFramerTest, BuildUniDiMaxStreamsPacket) { QuicFrames frames = {QuicFrame(frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -10879,8 +11205,8 @@ TEST_P(QuicFramerTest, BuildUniDiMaxStreamsPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, NewConnectionIdFrame) { @@ -10890,7 +11216,7 @@ TEST_P(QuicFramerTest, NewConnectionIdFrame) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10919,7 +11245,7 @@ TEST_P(QuicFramerTest, NewConnectionIdFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10939,7 +11265,7 @@ TEST_P(QuicFramerTest, NewConnectionIdFrame) { ASSERT_EQ(0u, visitor_.ack_frames_.size()); - CheckFramingBoundaries(packet99, QUIC_INVALID_NEW_CONNECTION_ID_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_NEW_CONNECTION_ID_DATA); } TEST_P(QuicFramerTest, NewConnectionIdFrameVariableLength) { @@ -10949,7 +11275,7 @@ TEST_P(QuicFramerTest, NewConnectionIdFrameVariableLength) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -10978,7 +11304,7 @@ TEST_P(QuicFramerTest, NewConnectionIdFrameVariableLength) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -10998,7 +11324,7 @@ TEST_P(QuicFramerTest, NewConnectionIdFrameVariableLength) { ASSERT_EQ(0u, visitor_.ack_frames_.size()); - CheckFramingBoundaries(packet99, QUIC_INVALID_NEW_CONNECTION_ID_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_NEW_CONNECTION_ID_DATA); } // Verifies that parsing a NEW_CONNECTION_ID frame with a length above the @@ -11010,7 +11336,7 @@ TEST_P(QuicFramerTest, InvalidLongNewConnectionIdFrame) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -11046,7 +11372,7 @@ TEST_P(QuicFramerTest, InvalidLongNewConnectionIdFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_NEW_CONNECTION_ID_DATA)); EXPECT_EQ("Invalid new connection ID length for version.", @@ -11062,7 +11388,7 @@ TEST_P(QuicFramerTest, InvalidRetirePriorToNewConnectionIdFrame) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -11091,7 +11417,7 @@ TEST_P(QuicFramerTest, InvalidRetirePriorToNewConnectionIdFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_FALSE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsError(QUIC_INVALID_NEW_CONNECTION_ID_DATA)); EXPECT_EQ("Retire_prior_to > sequence_number.", framer_.detailed_error()); @@ -11120,7 +11446,7 @@ TEST_P(QuicFramerTest, BuildNewConnectionIdFramePacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -11148,8 +11474,8 @@ TEST_P(QuicFramerTest, BuildNewConnectionIdFramePacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, NewTokenFrame) { @@ -11255,7 +11581,7 @@ TEST_P(QuicFramerTest, IetfStopSendingFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -11277,7 +11603,7 @@ TEST_P(QuicFramerTest, IetfStopSendingFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -11292,7 +11618,7 @@ TEST_P(QuicFramerTest, IetfStopSendingFrame) { EXPECT_EQ(static_cast(0x7654), visitor_.stop_sending_frame_.ietf_error_code); - CheckFramingBoundaries(packet99, QUIC_INVALID_STOP_SENDING_FRAME_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_STOP_SENDING_FRAME_DATA); } TEST_P(QuicFramerTest, BuildIetfStopSendingPacket) { @@ -11315,7 +11641,7 @@ TEST_P(QuicFramerTest, BuildIetfStopSendingPacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -11336,8 +11662,8 @@ TEST_P(QuicFramerTest, BuildIetfStopSendingPacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, IetfPathChallengeFrame) { @@ -11348,7 +11674,7 @@ TEST_P(QuicFramerTest, IetfPathChallengeFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -11368,7 +11694,7 @@ TEST_P(QuicFramerTest, IetfPathChallengeFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -11380,7 +11706,7 @@ TEST_P(QuicFramerTest, IetfPathChallengeFrame) { EXPECT_EQ(QuicPathFrameBuffer({{0, 1, 2, 3, 4, 5, 6, 7}}), visitor_.path_challenge_frame_.data_buffer); - CheckFramingBoundaries(packet99, QUIC_INVALID_PATH_CHALLENGE_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_PATH_CHALLENGE_DATA); } TEST_P(QuicFramerTest, BuildIetfPathChallengePacket) { @@ -11400,7 +11726,7 @@ TEST_P(QuicFramerTest, BuildIetfPathChallengePacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -11419,8 +11745,8 @@ TEST_P(QuicFramerTest, BuildIetfPathChallengePacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, IetfPathResponseFrame) { @@ -11431,7 +11757,7 @@ TEST_P(QuicFramerTest, IetfPathResponseFrame) { SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -11451,7 +11777,7 @@ TEST_P(QuicFramerTest, IetfPathResponseFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -11463,7 +11789,7 @@ TEST_P(QuicFramerTest, IetfPathResponseFrame) { EXPECT_EQ(QuicPathFrameBuffer({{0, 1, 2, 3, 4, 5, 6, 7}}), visitor_.path_response_frame_.data_buffer); - CheckFramingBoundaries(packet99, QUIC_INVALID_PATH_RESPONSE_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_PATH_RESPONSE_DATA); } TEST_P(QuicFramerTest, BuildIetfPathResponsePacket) { @@ -11483,7 +11809,7 @@ TEST_P(QuicFramerTest, BuildIetfPathResponsePacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -11502,8 +11828,8 @@ TEST_P(QuicFramerTest, BuildIetfPathResponsePacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, GetRetransmittableControlFrameSize) { @@ -12233,7 +12559,7 @@ TEST_P(QuicFramerTest, RetireConnectionIdFrame) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - PacketFragments packet99 = { + PacketFragments packet_ietf = { // type (short header, 4 byte packet number) {"", {0x43}}, @@ -12253,7 +12579,7 @@ TEST_P(QuicFramerTest, RetireConnectionIdFrame) { // clang-format on std::unique_ptr encrypted( - AssemblePacketFromFragments(packet99)); + AssemblePacketFromFragments(packet_ietf)); EXPECT_TRUE(framer_.ProcessPacket(*encrypted)); EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -12268,7 +12594,7 @@ TEST_P(QuicFramerTest, RetireConnectionIdFrame) { ASSERT_EQ(0u, visitor_.ack_frames_.size()); - CheckFramingBoundaries(packet99, QUIC_INVALID_RETIRE_CONNECTION_ID_DATA); + CheckFramingBoundaries(packet_ietf, QUIC_INVALID_RETIRE_CONNECTION_ID_DATA); } TEST_P(QuicFramerTest, BuildRetireConnectionIdFramePacket) { @@ -12289,7 +12615,7 @@ TEST_P(QuicFramerTest, BuildRetireConnectionIdFramePacket) { QuicFrames frames = {QuicFrame(&frame)}; // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -12308,8 +12634,8 @@ TEST_P(QuicFramerTest, BuildRetireConnectionIdFramePacket) { ASSERT_TRUE(data != nullptr); quiche::test::CompareCharArraysWithHexError( - "constructed packet", data->data(), data->length(), AsChars(packet99), - ABSL_ARRAYSIZE(packet99)); + "constructed packet", data->data(), data->length(), AsChars(packet_ietf), + ABSL_ARRAYSIZE(packet_ietf)); } TEST_P(QuicFramerTest, AckFrameWithInvalidLargestObserved) { @@ -12355,7 +12681,7 @@ TEST_P(QuicFramerTest, AckFrameWithInvalidLargestObserved) { 0x00 }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -12379,8 +12705,8 @@ TEST_P(QuicFramerTest, AckFrameWithInvalidLargestObserved) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; } @@ -12433,7 +12759,7 @@ TEST_P(QuicFramerTest, FirstAckBlockJustUnderFlow) { 0x00 }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -12457,8 +12783,8 @@ TEST_P(QuicFramerTest, FirstAckBlockJustUnderFlow) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -12533,7 +12859,7 @@ TEST_P(QuicFramerTest, ThirdAckBlockJustUnderflow) { 0x00 }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -12565,8 +12891,8 @@ TEST_P(QuicFramerTest, ThirdAckBlockJustUnderflow) { unsigned char* p = packet; size_t p_size = ABSL_ARRAYSIZE(packet); if (VersionHasIetfQuicFrames(framer_.transport_version())) { - p = packet99; - p_size = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_size = ABSL_ARRAYSIZE(packet_ietf); } else if (framer_.version().HasIetfInvariantHeader()) { p = packet46; p_size = ABSL_ARRAYSIZE(packet46); @@ -12647,7 +12973,7 @@ TEST_P(QuicFramerTest, CoalescedPacket) { 'O', '_', 'W', 'O', 'R', 'L', 'D', '?', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // first coalesced packet // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) @@ -12712,8 +13038,8 @@ TEST_P(QuicFramerTest, CoalescedPacket) { unsigned char* p = packet; size_t p_length = ABSL_ARRAYSIZE(packet); if (framer_.version().HasIetfQuicFrames()) { - p = packet99; - p_length = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); @@ -12789,7 +13115,7 @@ TEST_P(QuicFramerTest, CoalescedPacketWithUdpPadding) { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // first coalesced packet // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) @@ -12831,8 +13157,8 @@ TEST_P(QuicFramerTest, CoalescedPacketWithUdpPadding) { unsigned char* p = packet; size_t p_length = ABSL_ARRAYSIZE(packet); if (framer_.version().HasIetfQuicFrames()) { - p = packet99; - p_length = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); @@ -12917,7 +13243,7 @@ TEST_P(QuicFramerTest, CoalescedPacketWithDifferentVersion) { 'O', '_', 'W', 'O', 'R', 'L', 'D', '?', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // first coalesced packet // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) @@ -12982,8 +13308,8 @@ TEST_P(QuicFramerTest, CoalescedPacketWithDifferentVersion) { unsigned char* p = packet; size_t p_length = ABSL_ARRAYSIZE(packet); if (framer_.version().HasIetfQuicFrames()) { - p = packet99; - p_length = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); @@ -13289,7 +13615,7 @@ TEST_P(QuicFramerTest, UndecryptableCoalescedPacket) { 'O', '_', 'W', 'O', 'R', 'L', 'D', '?', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // first coalesced packet // public flags (long header with packet type HANDSHAKE and // 4-byte packet number) @@ -13355,8 +13681,8 @@ TEST_P(QuicFramerTest, UndecryptableCoalescedPacket) { unsigned char* p = packet; size_t p_length = ABSL_ARRAYSIZE(packet); if (framer_.version().HasIetfQuicFrames()) { - p = packet99; - p_length = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); @@ -13457,7 +13783,7 @@ TEST_P(QuicFramerTest, MismatchedCoalescedPacket) { 'O', '_', 'W', 'O', 'R', 'L', 'D', '?', }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // first coalesced packet // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) @@ -13522,8 +13848,8 @@ TEST_P(QuicFramerTest, MismatchedCoalescedPacket) { unsigned char* p = packet; size_t p_length = ABSL_ARRAYSIZE(packet); if (framer_.version().HasIetfQuicFrames()) { - p = packet99; - p_length = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); @@ -13586,7 +13912,7 @@ TEST_P(QuicFramerTest, InvalidCoalescedPacket) { 0xD3, // version would be here but we cut off the invalid coalesced header. }; - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // first coalesced packet // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) @@ -13627,8 +13953,8 @@ TEST_P(QuicFramerTest, InvalidCoalescedPacket) { unsigned char* p = packet; size_t p_length = ABSL_ARRAYSIZE(packet); if (framer_.version().HasIetfQuicFrames()) { - p = packet99; - p_length = ABSL_ARRAYSIZE(packet99); + p = packet_ietf; + p_length = ABSL_ARRAYSIZE(packet_ietf); } QuicEncryptedPacket encrypted(AsChars(p), p_length, false); @@ -13741,9 +14067,9 @@ TEST_P(QuicFramerTest, PacketHeaderWithVariableLengthConnectionId) { return; } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); - char connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, - 0x54, 0x32, 0x10, 0x42}; - QuicConnectionId connection_id(connection_id_bytes, + uint8_t connection_id_bytes[9] = {0xFE, 0xDC, 0xBA, 0x98, 0x76, + 0x54, 0x32, 0x10, 0x42}; + QuicConnectionId connection_id(reinterpret_cast(connection_id_bytes), sizeof(connection_id_bytes)); QuicFramerPeer::SetLargestPacketNumber(&framer_, kPacketNumber - 2); QuicFramerPeer::SetExpectedServerConnectionIDLength(&framer_, @@ -13820,7 +14146,7 @@ TEST_P(QuicFramerTest, MultiplePacketNumberSpaces) { // padding frame 0x00, }; - unsigned char long_header_packet99[] = { + unsigned char long_header_packet_ietf[] = { // public flags (long header with packet type ZERO_RTT_PROTECTED and // 4-byte packet number) 0xD3, @@ -13855,8 +14181,8 @@ TEST_P(QuicFramerTest, MultiplePacketNumberSpaces) { ABSL_ARRAYSIZE(long_header_packet), false))); } else { EXPECT_TRUE(framer_.ProcessPacket( - QuicEncryptedPacket(AsChars(long_header_packet99), - ABSL_ARRAYSIZE(long_header_packet99), false))); + QuicEncryptedPacket(AsChars(long_header_packet_ietf), + ABSL_ARRAYSIZE(long_header_packet_ietf), false))); } EXPECT_THAT(framer_.error(), IsQuicNoError()); @@ -14042,7 +14368,7 @@ TEST_P(QuicFramerTest, ProcessMismatchedHeaderVersion) { TEST_P(QuicFramerTest, WriteClientVersionNegotiationProbePacket) { // clang-format off - static const char expected_packet[1200] = { + static const uint8_t expected_packet[1200] = { // IETF long header with fixed bit set, type initial, all-0 encrypted bits. 0xc0, // Version, part of the IETF space reserved for negotiation. @@ -14091,9 +14417,9 @@ TEST_P(QuicFramerTest, WriteClientVersionNegotiationProbePacket) { EXPECT_TRUE(QuicFramer::WriteClientVersionNegotiationProbePacket( packet, sizeof(packet), destination_connection_id_bytes, sizeof(destination_connection_id_bytes))); - quiche::test::CompareCharArraysWithHexError("constructed packet", packet, - sizeof(packet), expected_packet, - sizeof(expected_packet)); + quiche::test::CompareCharArraysWithHexError( + "constructed packet", packet, sizeof(packet), + reinterpret_cast(expected_packet), sizeof(expected_packet)); QuicEncryptedPacket encrypted(reinterpret_cast(packet), sizeof(packet), false); if (!framer_.version().HasLengthPrefixedConnectionIds()) { @@ -14113,7 +14439,7 @@ TEST_P(QuicFramerTest, WriteClientVersionNegotiationProbePacket) { TEST_P(QuicFramerTest, DispatcherParseOldClientVersionNegotiationProbePacket) { // clang-format off - static const char packet[1200] = { + static const uint8_t packet[1200] = { // IETF long header with fixed bit set, type initial, all-0 encrypted bits. 0xc0, // Version, part of the IETF space reserved for negotiation. @@ -14169,14 +14495,13 @@ TEST_P(QuicFramerTest, DispatcherParseOldClientVersionNegotiationProbePacket) { ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); QuicConnectionId destination_connection_id = TestConnectionId(1); QuicConnectionId source_connection_id = TestConnectionId(2); - bool retry_token_present = true; - absl::string_view retry_token; + absl::optional retry_token; std::string detailed_error = "foobar"; QuicErrorCode header_parse_result = QuicFramer::ParsePublicHeaderDispatcher( encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_present, &has_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); EXPECT_THAT(header_parse_result, IsQuicNoError()); EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); EXPECT_TRUE(version_present); @@ -14184,13 +14509,13 @@ TEST_P(QuicFramerTest, DispatcherParseOldClientVersionNegotiationProbePacket) { EXPECT_EQ(0xcabadaba, version_label); EXPECT_EQ(expected_destination_connection_id, destination_connection_id); EXPECT_EQ(EmptyQuicConnectionId(), source_connection_id); - EXPECT_FALSE(retry_token_present); + EXPECT_FALSE(retry_token.has_value()); EXPECT_EQ("", detailed_error); } TEST_P(QuicFramerTest, DispatcherParseClientVersionNegotiationProbePacket) { // clang-format off - static const char packet[1200] = { + static const uint8_t packet[1200] = { // IETF long header with fixed bit set, type initial, all-0 encrypted bits. 0xc0, // Version, part of the IETF space reserved for negotiation. @@ -14248,14 +14573,13 @@ TEST_P(QuicFramerTest, DispatcherParseClientVersionNegotiationProbePacket) { ParsedQuicVersion parsed_version = UnsupportedQuicVersion(); QuicConnectionId destination_connection_id = TestConnectionId(1); QuicConnectionId source_connection_id = TestConnectionId(2); - bool retry_token_present = true; - absl::string_view retry_token; + absl::optional retry_token; std::string detailed_error = "foobar"; QuicErrorCode header_parse_result = QuicFramer::ParsePublicHeaderDispatcher( encrypted, kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_present, &has_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); EXPECT_THAT(header_parse_result, IsQuicNoError()); EXPECT_EQ(IETF_QUIC_LONG_HEADER_PACKET, format); EXPECT_TRUE(version_present); @@ -14268,7 +14592,7 @@ TEST_P(QuicFramerTest, DispatcherParseClientVersionNegotiationProbePacket) { TEST_P(QuicFramerTest, ParseServerVersionNegotiationProbeResponse) { // clang-format off - const char packet[] = { + const uint8_t packet[] = { // IETF long header with fixed bit set, type initial, all-0 encrypted bits. 0xc0, // Version of 0, indicating version negotiation. @@ -14644,7 +14968,7 @@ TEST_P(QuicFramerTest, OverlyLargeAckDelay) { } SetDecrypterLevel(ENCRYPTION_FORWARD_SECURE); // clang-format off - unsigned char packet99[] = { + unsigned char packet_ietf[] = { // type (short header, 4 byte packet number) 0x43, // connection_id @@ -14665,8 +14989,8 @@ TEST_P(QuicFramerTest, OverlyLargeAckDelay) { }; // clang-format on - framer_.ProcessPacket( - QuicEncryptedPacket(AsChars(packet99), ABSL_ARRAYSIZE(packet99), false)); + framer_.ProcessPacket(QuicEncryptedPacket( + AsChars(packet_ietf), ABSL_ARRAYSIZE(packet_ietf), false)); ASSERT_EQ(1u, visitor_.ack_frames_.size()); // Verify ack_delay_time is set correctly. EXPECT_EQ(QuicTime::Delta::Infinite(), @@ -15230,7 +15554,6 @@ TEST_P(QuicFramerTest, KeyUpdateOnFirstReceivedPacket) { // Key update is only used in QUIC+TLS. return; } - SetQuicReloadableFlag(quic_fix_key_update_on_first_packet, true); ASSERT_TRUE(framer_.version().KnowsWhichDecrypterToUse()); // Doesn't use SetDecrypterLevel since we want to use StrictTaggingDecrypter // instead of TestDecrypter. @@ -15262,8 +15585,7 @@ TEST_P(QuicFramerTest, KeyUpdateOnFirstReceivedPacket) { } TEST_P(QuicFramerTest, ErrorWhenUnexpectedFrameTypeEncountered) { - if (!GetQuicReloadableFlag(quic_reject_unexpected_ietf_frame_types) || - !VersionHasIetfQuicFrames(framer_.transport_version()) || + if (!VersionHasIetfQuicFrames(framer_.transport_version()) || !QuicVersionHasLongHeaderLengths(framer_.transport_version()) || !framer_.version().HasLongHeaderLengths()) { return; diff --git a/gquiche/quic/core/quic_idle_network_detector.cc b/gquiche/quic/core/quic_idle_network_detector.cc index 3a92c548..4ab32605 100644 --- a/gquiche/quic/core/quic_idle_network_detector.cc +++ b/gquiche/quic/core/quic_idle_network_detector.cc @@ -5,15 +5,17 @@ #include "gquiche/quic/core/quic_idle_network_detector.h" #include "gquiche/quic/core/quic_constants.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" namespace quic { namespace { -class AlarmDelegate : public QuicAlarm::Delegate { +class AlarmDelegate : public QuicAlarm::DelegateWithContext { public: - explicit AlarmDelegate(QuicIdleNetworkDetector* detector) - : detector_(detector) {} + explicit AlarmDelegate(QuicIdleNetworkDetector* detector, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), detector_(detector) {} AlarmDelegate(const AlarmDelegate&) = delete; AlarmDelegate& operator=(const AlarmDelegate&) = delete; @@ -26,18 +28,16 @@ class AlarmDelegate : public QuicAlarm::Delegate { } // namespace QuicIdleNetworkDetector::QuicIdleNetworkDetector( - Delegate* delegate, - QuicTime now, - QuicConnectionArena* arena, - QuicAlarmFactory* alarm_factory) + Delegate* delegate, QuicTime now, QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, QuicConnectionContext* context) : delegate_(delegate), start_time_(now), handshake_timeout_(QuicTime::Delta::Infinite()), time_of_last_received_packet_(now), time_of_first_packet_sent_after_receiving_(QuicTime::Zero()), idle_network_timeout_(QuicTime::Delta::Infinite()), - alarm_( - alarm_factory->CreateAlarm(arena->New(this), arena)) {} + alarm_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)) {} void QuicIdleNetworkDetector::OnAlarm() { if (handshake_timeout_.IsInfinite()) { @@ -66,9 +66,10 @@ void QuicIdleNetworkDetector::SetTimeouts( } void QuicIdleNetworkDetector::StopDetection() { - alarm_->Cancel(); + alarm_->PermanentCancel(); handshake_timeout_ = QuicTime::Delta::Infinite(); idle_network_timeout_ = QuicTime::Delta::Infinite(); + stopped_ = true; } void QuicIdleNetworkDetector::OnPacketSent(QuicTime now, @@ -94,6 +95,14 @@ void QuicIdleNetworkDetector::OnPacketReceived(QuicTime now) { } void QuicIdleNetworkDetector::SetAlarm() { + if (stopped_) { + // TODO(wub): If this QUIC_BUG fires, it indicates a problem in the + // QuicConnection, which somehow called this function while disconnected. + // That problem needs to be fixed. + QUIC_BUG(quic_idle_detector_set_alarm_after_stopped) + << "SetAlarm called after stopped"; + return; + } // Set alarm to the nearer deadline. QuicTime new_deadline = QuicTime::Zero(); if (!handshake_timeout_.IsInfinite()) { diff --git a/gquiche/quic/core/quic_idle_network_detector.h b/gquiche/quic/core/quic_idle_network_detector.h index f4c76925..f8f7af41 100644 --- a/gquiche/quic/core/quic_idle_network_detector.h +++ b/gquiche/quic/core/quic_idle_network_detector.h @@ -10,6 +10,7 @@ #include "gquiche/quic/core/quic_one_block_arena.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/quic/platform/api/quic_flags.h" namespace quic { @@ -35,10 +36,10 @@ class QUIC_EXPORT_PRIVATE QuicIdleNetworkDetector { virtual void OnIdleNetworkDetected() = 0; }; - QuicIdleNetworkDetector(Delegate* delegate, - QuicTime now, + QuicIdleNetworkDetector(Delegate* delegate, QuicTime now, QuicConnectionArena* arena, - QuicAlarmFactory* alarm_factory); + QuicAlarmFactory* alarm_factory, + QuicConnectionContext* context); void OnAlarm(); @@ -46,6 +47,7 @@ class QUIC_EXPORT_PRIVATE QuicIdleNetworkDetector { void SetTimeouts(QuicTime::Delta handshake_timeout, QuicTime::Delta idle_network_timeout); + // Stop the detection once and for all. void StopDetection(); // Called when a packet gets sent. @@ -106,6 +108,9 @@ class QUIC_EXPORT_PRIVATE QuicIdleNetworkDetector { QuicArenaScopedPtr alarm_; bool shorter_idle_timeout_on_sent_packet_ = false; + + // Whether |StopDetection| has been called. + bool stopped_ = false; }; } // namespace quic diff --git a/gquiche/quic/core/quic_idle_network_detector_test.cc b/gquiche/quic/core/quic_idle_network_detector_test.cc index 6284fe67..151fd209 100644 --- a/gquiche/quic/core/quic_idle_network_detector_test.cc +++ b/gquiche/quic/core/quic_idle_network_detector_test.cc @@ -5,6 +5,7 @@ #include "gquiche/quic/core/quic_idle_network_detector.h" #include "gquiche/quic/core/quic_one_block_arena.h" +#include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" @@ -31,7 +32,8 @@ class QuicIdleNetworkDetectorTest : public QuicTest { QuicIdleNetworkDetectorTest() { clock_.AdvanceTime(QuicTime::Delta::FromSeconds(1)); detector_ = std::make_unique( - &delegate_, clock_.Now(), &arena_, &alarm_factory_); + &delegate_, clock_.Now(), &arena_, &alarm_factory_, + /*context=*/nullptr); alarm_ = static_cast( QuicIdleNetworkDetectorTestPeer::GetAlarm(detector_.get())); } @@ -181,6 +183,17 @@ TEST_F(QuicIdleNetworkDetectorTest, ShorterIdleTimeoutOnSentPacket) { EXPECT_EQ(clock_.Now() + QuicTime::Delta::FromSeconds(2), alarm_->deadline()); } +TEST_F(QuicIdleNetworkDetectorTest, NoAlarmAfterStopped) { + detector_->StopDetection(); + + EXPECT_QUIC_BUG( + detector_->SetTimeouts( + /*handshake_timeout=*/QuicTime::Delta::FromSeconds(30), + /*idle_network_timeout=*/QuicTime::Delta::FromSeconds(20)), + "SetAlarm called after stopped"); + EXPECT_FALSE(alarm_->IsSet()); +} + } // namespace } // namespace test diff --git a/gquiche/quic/core/quic_interval_deque.h b/gquiche/quic/core/quic_interval_deque.h index 9448ecd8..e677c4c0 100644 --- a/gquiche/quic/core/quic_interval_deque.h +++ b/gquiche/quic/core/quic_interval_deque.h @@ -8,12 +8,12 @@ #include #include "absl/types/optional.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_interval.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -137,7 +137,7 @@ class QuicIntervalDequePeer; // // cached_index -> 1 // // container -> {{2, [25, 30)}, {3, [35, 50)}} -template > +template > class QUIC_NO_EXPORT QuicIntervalDeque { public: class QUIC_NO_EXPORT Iterator { @@ -363,7 +363,7 @@ void QuicIntervalDeque::PushBackUniversal(U&& item) { // Adding an empty interval is a bug. if (interval.Empty()) { QUIC_BUG(quic_bug_10862_3) - << "Trying to save empty interval to QuicCircularDeque."; + << "Trying to save empty interval to quiche::QuicheCircularDeque."; return; } container_.push_back(std::forward(item)); diff --git a/gquiche/quic/core/quic_interval_set.h b/gquiche/quic/core/quic_interval_set.h index 2fcf0506..a0fff66e 100644 --- a/gquiche/quic/core/quic_interval_set.h +++ b/gquiche/quic/core/quic_interval_set.h @@ -85,9 +85,7 @@ class QUIC_NO_EXPORT QuicIntervalSet { bool operator()(T&& point, const value_type& a) const; }; - using Set = QuicOrderedSet>; + using Set = QuicSmallOrderedSet; public: using const_iterator = typename Set::const_iterator; diff --git a/gquiche/quic/core/quic_legacy_version_encapsulator.cc b/gquiche/quic/core/quic_legacy_version_encapsulator.cc index 30a71381..f76f1c1a 100644 --- a/gquiche/quic/core/quic_legacy_version_encapsulator.cc +++ b/gquiche/quic/core/quic_legacy_version_encapsulator.cc @@ -7,7 +7,6 @@ #include "gquiche/quic/core/crypto/crypto_protocol.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/quic_legacy_version_encapsulator_test.cc b/gquiche/quic/core/quic_legacy_version_encapsulator_test.cc index 142b6854..85446799 100644 --- a/gquiche/quic/core/quic_legacy_version_encapsulator_test.cc +++ b/gquiche/quic/core/quic_legacy_version_encapsulator_test.cc @@ -9,7 +9,6 @@ #include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { @@ -42,15 +41,14 @@ class QuicLegacyVersionEncapsulatorTest QuicVersionLabel version_label; ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); QuicConnectionId destination_connection_id, source_connection_id; - bool retry_token_present; - absl::string_view retry_token; + absl::optional retry_token; std::string detailed_error; const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( QuicEncryptedPacket(outer_buffer_, encapsulated_length_), kQuicDefaultConnectionIdLength, &format, &long_packet_type, &version_present, &has_length_prefix, &version_label, &parsed_version, - &destination_connection_id, &source_connection_id, &retry_token_present, - &retry_token, &detailed_error); + &destination_connection_id, &source_connection_id, &retry_token, + &detailed_error); ASSERT_THAT(error, IsQuicNoError()) << detailed_error; EXPECT_EQ(format, GOOGLE_QUIC_PACKET); EXPECT_TRUE(version_present); @@ -58,7 +56,7 @@ class QuicLegacyVersionEncapsulatorTest EXPECT_EQ(parsed_version, LegacyVersionForEncapsulation()); EXPECT_EQ(destination_connection_id, server_connection_id_); EXPECT_EQ(source_connection_id, EmptyQuicConnectionId()); - EXPECT_FALSE(retry_token_present); + EXPECT_FALSE(retry_token.has_value()); EXPECT_TRUE(detailed_error.empty()); } diff --git a/gquiche/quic/core/quic_linux_socket_utils.h b/gquiche/quic/core/quic_linux_socket_utils.h index 4969ebb8..6113678b 100644 --- a/gquiche/quic/core/quic_linux_socket_utils.h +++ b/gquiche/quic/core/quic_linux_socket_utils.h @@ -148,7 +148,7 @@ struct QUIC_EXPORT_PRIVATE BufferedWrite { // multiple packets at once via ::sendmmsg. // // Example: -// QuicCircularDeque buffered_writes; +// quiche::QuicheCircularDeque buffered_writes; // ... (Populate buffered_writes) ... // // QuicMMsgHdr mhdr( diff --git a/gquiche/quic/core/quic_linux_socket_utils_test.cc b/gquiche/quic/core/quic_linux_socket_utils_test.cc index d22afdb2..af9fcf0f 100644 --- a/gquiche/quic/core/quic_linux_socket_utils_test.cc +++ b/gquiche/quic/core/quic_linux_socket_utils_test.cc @@ -12,9 +12,9 @@ #include -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/quic_mock_syscall_wrapper.h" +#include "gquiche/common/quiche_circular_deque.h" using testing::_; using testing::InSequence; @@ -28,8 +28,8 @@ class QuicLinuxSocketUtilsTest : public QuicTest { protected: WriteResult TestWriteMultiplePackets( int fd, - const QuicCircularDeque::const_iterator& first, - const QuicCircularDeque::const_iterator& last, + const quiche::QuicheCircularDeque::const_iterator& first, + const quiche::QuicheCircularDeque::const_iterator& last, int* num_packets_sent) { QuicMMsgHdr mhdr( first, last, kCmsgSpaceForIp, @@ -171,7 +171,7 @@ TEST_F(QuicLinuxSocketUtilsTest, QuicMsgHdr) { } TEST_F(QuicLinuxSocketUtilsTest, QuicMMsgHdr) { - QuicCircularDeque buffered_writes; + quiche::QuicheCircularDeque buffered_writes; char packet_buf1[1024]; char packet_buf2[512]; buffered_writes.emplace_back( @@ -205,7 +205,7 @@ TEST_F(QuicLinuxSocketUtilsTest, QuicMMsgHdr) { TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_NoPacketsToSend) { int num_packets_sent; - QuicCircularDeque buffered_writes; + quiche::QuicheCircularDeque buffered_writes; EXPECT_CALL(mock_syscalls_, Sendmmsg(_, _, _, _)).Times(0); @@ -216,7 +216,7 @@ TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_NoPacketsToSend) { TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteBlocked) { int num_packets_sent; - QuicCircularDeque buffered_writes; + quiche::QuicheCircularDeque buffered_writes; buffered_writes.emplace_back(nullptr, 0, QuicIpAddress(), QuicSocketAddress(QuicIpAddress::Any4(), 0)); @@ -235,7 +235,7 @@ TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteBlocked) { TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteError) { int num_packets_sent; - QuicCircularDeque buffered_writes; + quiche::QuicheCircularDeque buffered_writes; buffered_writes.emplace_back(nullptr, 0, QuicIpAddress(), QuicSocketAddress(QuicIpAddress::Any4(), 0)); @@ -254,7 +254,7 @@ TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteError) { TEST_F(QuicLinuxSocketUtilsTest, WriteMultiplePackets_WriteSuccess) { int num_packets_sent; - QuicCircularDeque buffered_writes; + quiche::QuicheCircularDeque buffered_writes; const int kNumBufferedWrites = 10; static_assert(kNumBufferedWrites < 256, "Must be less than 256"); std::vector buffer_holder; diff --git a/gquiche/quic/core/quic_lru_cache.h b/gquiche/quic/core/quic_lru_cache.h index a54a1dcd..a14cb5d9 100644 --- a/gquiche/quic/core/quic_lru_cache.h +++ b/gquiche/quic/core/quic_lru_cache.h @@ -7,11 +7,11 @@ #include -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -19,13 +19,36 @@ namespace quic { // This cache CANNOT be shared by multiple threads (even with locks) because // Value* returned by Lookup() can be invalid if the entry is evicted by other // threads. -template +template , + class Eq = std::equal_to> class QUIC_NO_EXPORT QuicLRUCache { + private: + using HashMapType = + typename quiche::QuicheLinkedHashMap, Hash, Eq>; + public: + // The iterator, if valid, points to std::pair>. + using iterator = typename HashMapType::iterator; + using const_iterator = typename HashMapType::const_iterator; + using reverse_iterator = typename HashMapType::reverse_iterator; + using const_reverse_iterator = typename HashMapType::const_reverse_iterator; + explicit QuicLRUCache(size_t capacity) : capacity_(capacity) {} QuicLRUCache(const QuicLRUCache&) = delete; QuicLRUCache& operator=(const QuicLRUCache&) = delete; + iterator begin() { return cache_.begin(); } + const_iterator begin() const { return cache_.begin(); } + + iterator end() { return cache_.end(); } + const_iterator end() const { return cache_.end(); } + + reverse_iterator rbegin() { return cache_.rbegin(); } + const_reverse_iterator rbegin() const { return cache_.rbegin(); } + + reverse_iterator rend() { return cache_.rend(); } + const_reverse_iterator rend() const { return cache_.rend(); } + // Inserts one unit of |key|, |value| pair to the cache. Cache takes ownership // of inserted |value|. void Insert(const K& key, std::unique_ptr value) { @@ -41,22 +64,21 @@ class QUIC_NO_EXPORT QuicLRUCache { QUICHE_DCHECK_LE(cache_.size(), capacity_); } - // If cache contains an entry for |key|, return a pointer to it. This returned - // value is guaranteed to be valid until Insert or Clear. - // Else return nullptr. - V* Lookup(const K& key) { - auto it = cache_.find(key); - if (it == cache_.end()) { - return nullptr; + iterator Lookup(const K& key) { + auto iter = cache_.find(key); + if (iter == cache_.end()) { + return iter; } - std::unique_ptr value = std::move(it->second); - cache_.erase(it); + std::unique_ptr value = std::move(iter->second); + cache_.erase(iter); auto result = cache_.emplace(key, std::move(value)); QUICHE_DCHECK(result.second); - return result.first->second.get(); + return result.first; } + iterator Erase(iterator iter) { return cache_.erase(iter); } + // Removes all entries from the cache. void Clear() { cache_.clear(); } @@ -67,7 +89,7 @@ class QUIC_NO_EXPORT QuicLRUCache { size_t Size() const { return cache_.size(); } private: - QuicLinkedHashMap> cache_; + quiche::QuicheLinkedHashMap, Hash, Eq> cache_; const size_t capacity_; }; diff --git a/gquiche/quic/core/quic_lru_cache_test.cc b/gquiche/quic/core/quic_lru_cache_test.cc index a64e4cf6..31dc3818 100644 --- a/gquiche/quic/core/quic_lru_cache_test.cc +++ b/gquiche/quic/core/quic_lru_cache_test.cc @@ -18,7 +18,7 @@ struct CachedItem { TEST(QuicLRUCacheTest, InsertAndLookup) { QuicLRUCache cache(5); - EXPECT_EQ(nullptr, cache.Lookup(1)); + EXPECT_EQ(cache.end(), cache.Lookup(1)); EXPECT_EQ(0u, cache.Size()); EXPECT_EQ(5u, cache.MaxSize()); @@ -26,18 +26,23 @@ TEST(QuicLRUCacheTest, InsertAndLookup) { std::unique_ptr item1(new CachedItem(11)); cache.Insert(1, std::move(item1)); EXPECT_EQ(1u, cache.Size()); - EXPECT_EQ(11u, cache.Lookup(1)->value); + EXPECT_EQ(11u, cache.Lookup(1)->second->value); // Check that item 2 overrides item 1. std::unique_ptr item2(new CachedItem(12)); cache.Insert(1, std::move(item2)); EXPECT_EQ(1u, cache.Size()); - EXPECT_EQ(12u, cache.Lookup(1)->value); + EXPECT_EQ(12u, cache.Lookup(1)->second->value); std::unique_ptr item3(new CachedItem(13)); cache.Insert(3, std::move(item3)); EXPECT_EQ(2u, cache.Size()); - EXPECT_EQ(13u, cache.Lookup(3)->value); + auto iter = cache.Lookup(3); + ASSERT_NE(cache.end(), iter); + EXPECT_EQ(13u, iter->second->value); + cache.Erase(iter); + ASSERT_EQ(cache.end(), cache.Lookup(3)); + EXPECT_EQ(1u, cache.Size()); // No memory leakage. cache.Clear(); @@ -56,15 +61,15 @@ TEST(QuicLRUCacheTest, Eviction) { EXPECT_EQ(3u, cache.MaxSize()); // Make sure item 1 is evicted. - EXPECT_EQ(nullptr, cache.Lookup(1)); - EXPECT_EQ(14u, cache.Lookup(4)->value); + EXPECT_EQ(cache.end(), cache.Lookup(1)); + EXPECT_EQ(14u, cache.Lookup(4)->second->value); - EXPECT_EQ(12u, cache.Lookup(2)->value); + EXPECT_EQ(12u, cache.Lookup(2)->second->value); std::unique_ptr item5(new CachedItem(15)); cache.Insert(5, std::move(item5)); // Make sure item 3 is evicted. - EXPECT_EQ(nullptr, cache.Lookup(3)); - EXPECT_EQ(15u, cache.Lookup(5)->value); + EXPECT_EQ(cache.end(), cache.Lookup(3)); + EXPECT_EQ(15u, cache.Lookup(5)->second->value); // No memory leakage. cache.Clear(); diff --git a/gquiche/quic/core/quic_mtu_discovery.h b/gquiche/quic/core/quic_mtu_discovery.h index 973fcb2d..048afc44 100644 --- a/gquiche/quic/core/quic_mtu_discovery.h +++ b/gquiche/quic/core/quic_mtu_discovery.h @@ -27,9 +27,9 @@ static_assert(kMtuDiscoveryAttempts + 8 < 8 * sizeof(QuicPacketNumber), static_assert(kPacketsBetweenMtuProbesBase < (1 << 8), "The initial number of packets between MTU probes is too high"); -// The incresed packet size targeted when doing path MTU discovery. -const QuicByteCount kMtuDiscoveryTargetPacketSizeHigh = 1450; -const QuicByteCount kMtuDiscoveryTargetPacketSizeLow = 1430; +// The increased packet size targeted when doing path MTU discovery. +const QuicByteCount kMtuDiscoveryTargetPacketSizeHigh = 1400; +const QuicByteCount kMtuDiscoveryTargetPacketSizeLow = 1380; static_assert(kMtuDiscoveryTargetPacketSizeLow <= kMaxOutgoingPacketSize, "MTU discovery target is too large"); diff --git a/gquiche/quic/core/quic_network_blackhole_detector.cc b/gquiche/quic/core/quic_network_blackhole_detector.cc index efa7499d..89512df0 100644 --- a/gquiche/quic/core/quic_network_blackhole_detector.cc +++ b/gquiche/quic/core/quic_network_blackhole_detector.cc @@ -10,10 +10,11 @@ namespace quic { namespace { -class AlarmDelegate : public QuicAlarm::Delegate { +class AlarmDelegate : public QuicAlarm::DelegateWithContext { public: - explicit AlarmDelegate(QuicNetworkBlackholeDetector* detector) - : detector_(detector) {} + explicit AlarmDelegate(QuicNetworkBlackholeDetector* detector, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), detector_(detector) {} AlarmDelegate(const AlarmDelegate&) = delete; AlarmDelegate& operator=(const AlarmDelegate&) = delete; @@ -26,12 +27,11 @@ class AlarmDelegate : public QuicAlarm::Delegate { } // namespace QuicNetworkBlackholeDetector::QuicNetworkBlackholeDetector( - Delegate* delegate, - QuicConnectionArena* arena, - QuicAlarmFactory* alarm_factory) + Delegate* delegate, QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, QuicConnectionContext* context) : delegate_(delegate), - alarm_( - alarm_factory->CreateAlarm(arena->New(this), arena)) {} + alarm_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)) {} void QuicNetworkBlackholeDetector::OnAlarm() { QuicTime next_deadline = GetEarliestDeadline(); @@ -64,8 +64,12 @@ void QuicNetworkBlackholeDetector::OnAlarm() { UpdateAlarm(); } -void QuicNetworkBlackholeDetector::StopDetection() { - alarm_->Cancel(); +void QuicNetworkBlackholeDetector::StopDetection(bool permanent) { + if (permanent) { + alarm_->PermanentCancel(); + } else { + alarm_->Cancel(); + } path_degrading_deadline_ = QuicTime::Zero(); blackhole_deadline_ = QuicTime::Zero(); path_mtu_reduction_deadline_ = QuicTime::Zero(); @@ -108,6 +112,12 @@ QuicTime QuicNetworkBlackholeDetector::GetLastDeadline() const { } void QuicNetworkBlackholeDetector::UpdateAlarm() const { + // If called after OnBlackholeDetected(), the alarm may have been permanently + // cancelled and is not safe to be armed again. + if (alarm_->IsPermanentlyCancelled()) { + return; + } + QuicTime next_deadline = GetEarliestDeadline(); QUIC_DVLOG(1) << "Updating alarm. next_deadline:" << next_deadline diff --git a/gquiche/quic/core/quic_network_blackhole_detector.h b/gquiche/quic/core/quic_network_blackhole_detector.h index 254a1ffc..315b938f 100644 --- a/gquiche/quic/core/quic_network_blackhole_detector.h +++ b/gquiche/quic/core/quic_network_blackhole_detector.h @@ -40,12 +40,13 @@ class QUIC_EXPORT_PRIVATE QuicNetworkBlackholeDetector { virtual void OnPathMtuReductionDetected() = 0; }; - QuicNetworkBlackholeDetector(Delegate* delegate, - QuicConnectionArena* arena, - QuicAlarmFactory* alarm_factory); + QuicNetworkBlackholeDetector(Delegate* delegate, QuicConnectionArena* arena, + QuicAlarmFactory* alarm_factory, + QuicConnectionContext* context); - // Called to stop all detections. - void StopDetection(); + // Called to stop all detections. If |permanent|, the alarm will be cancelled + // permanently and future calls to RestartDetection will be no-op. + void StopDetection(bool permanent); // Called to restart path degrading, path mtu reduction and blackhole // detections. Please note, if |blackhole_deadline| is set, it must be the diff --git a/gquiche/quic/core/quic_network_blackhole_detector_test.cc b/gquiche/quic/core/quic_network_blackhole_detector_test.cc index fbfadb17..73b1014f 100644 --- a/gquiche/quic/core/quic_network_blackhole_detector_test.cc +++ b/gquiche/quic/core/quic_network_blackhole_detector_test.cc @@ -33,7 +33,7 @@ const size_t kBlackholeDelayInSeconds = 10; class QuicNetworkBlackholeDetectorTest : public QuicTest { public: QuicNetworkBlackholeDetectorTest() - : detector_(&delegate_, &arena_, &alarm_factory_), + : detector_(&delegate_, &arena_, &alarm_factory_, /*context=*/nullptr), alarm_(static_cast( QuicNetworkBlackholeDetectorPeer::GetAlarm(&detector_))), path_degrading_delay_( @@ -106,7 +106,7 @@ TEST_F(QuicNetworkBlackholeDetectorTest, RestartAndStop) { RestartDetection(); EXPECT_EQ(clock_.Now() + path_degrading_delay_, alarm_->deadline()); - detector_.StopDetection(); + detector_.StopDetection(/*permanent=*/false); EXPECT_FALSE(detector_.IsDetectionInProgress()); } diff --git a/gquiche/quic/core/quic_one_block_arena.h b/gquiche/quic/core/quic_one_block_arena.h index 77c2922e..d86d0ef7 100644 --- a/gquiche/quic/core/quic_one_block_arena.h +++ b/gquiche/quic/core/quic_one_block_arena.h @@ -69,7 +69,7 @@ class QUIC_EXPORT_PRIVATE QuicOneBlockArena { // QuicConnections currently use around 1KB of polymorphic types which would // ordinarily be on the heap. Instead, store them inline in an arena. -using QuicConnectionArena = QuicOneBlockArena<1056>; +using QuicConnectionArena = QuicOneBlockArena<1152>; } // namespace quic diff --git a/gquiche/quic/core/quic_packet_creator.cc b/gquiche/quic/core/quic_packet_creator.cc index 89d6cbcc..0fe637e4 100644 --- a/gquiche/quic/core/quic_packet_creator.cc +++ b/gquiche/quic/core/quic_packet_creator.cc @@ -15,10 +15,12 @@ #include "absl/base/optimization.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "gquiche/quic/core/crypto/crypto_protocol.h" #include "gquiche/quic/core/frames/quic_frame.h" #include "gquiche/quic/core/frames/quic_path_challenge_frame.h" #include "gquiche/quic/core/frames/quic_stream_frame.h" +#include "gquiche/quic/core/quic_chaos_protector.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/core/quic_data_writer.h" @@ -32,7 +34,7 @@ #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_server_stats.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/print_elements.h" namespace quic { namespace { @@ -107,8 +109,7 @@ QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, delegate) {} QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, - QuicFramer* framer, - QuicRandom* random, + QuicFramer* framer, QuicRandom* random, DelegateInterface* delegate) : delegate_(delegate), debug_delegate_(nullptr), @@ -121,11 +122,7 @@ QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, packet_size_(0), server_connection_id_(server_connection_id), client_connection_id_(EmptyQuicConnectionId()), - packet_(QuicPacketNumber(), - PACKET_1BYTE_PACKET_NUMBER, - nullptr, - 0, - false, + packet_(QuicPacketNumber(), PACKET_1BYTE_PACKET_NUMBER, nullptr, 0, false, false), pending_padding_bytes_(0), needs_full_padding_(false), @@ -133,7 +130,10 @@ QuicPacketCreator::QuicPacketCreator(QuicConnectionId server_connection_id, flusher_attached_(false), fully_pad_crypto_handshake_packets_(true), latched_hard_max_packet_length_(0), - max_datagram_frame_size_(0) { + max_datagram_frame_size_(0), + chaos_protection_enabled_( + GetQuicFlag(FLAGS_quic_enable_chaos_protection) && + framer->perspective() == Perspective::IS_CLIENT) { SetMaxPacketLength(kDefaultMaxPacketSize); if (!framer_->version().UsesTls()) { // QUIC+TLS negotiates the maximum datagram frame size via the @@ -694,6 +694,10 @@ bool QuicPacketCreator::HasPendingFrames() const { return !queued_frames_.empty(); } +std::string QuicPacketCreator::GetPendingFramesInfo() const { + return QuicFramesToString(queued_frames_); +} + bool QuicPacketCreator::HasPendingRetransmittableFrames() const { return !packet_.retransmittable_frames.empty(); } @@ -753,6 +757,35 @@ bool QuicPacketCreator::AddPaddedSavedFrame( return false; } +absl::optional +QuicPacketCreator::MaybeBuildDataPacketWithChaosProtection( + const QuicPacketHeader& header, + char* buffer) { + if (!chaos_protection_enabled_ || + packet_.encryption_level != ENCRYPTION_INITIAL || + !framer_->version().UsesCryptoFrames() || queued_frames_.size() != 2u || + queued_frames_[0].type != CRYPTO_FRAME || + queued_frames_[1].type != PADDING_FRAME || + // Do not perform chaos protection if we do not have a known number of + // padding bytes to work with. + queued_frames_[1].padding_frame.num_padding_bytes <= 0 || + // Chaos protection relies on the framer using a crypto data producer, + // which is always the case in practice. + framer_->data_producer() == nullptr) { + return absl::nullopt; + } + const QuicCryptoFrame& crypto_frame = *queued_frames_[0].crypto_frame; + if (packet_.encryption_level != crypto_frame.level) { + QUIC_BUG(chaos frame level) + << ENDPOINT << packet_.encryption_level << " != " << crypto_frame.level; + return absl::nullopt; + } + QuicChaosProtector chaos_protector( + crypto_frame, queued_frames_[1].padding_frame.num_padding_bytes, + packet_size_, framer_, random_); + return chaos_protector.BuildDataPacket(header, buffer); +} + bool QuicPacketCreator::SerializePacket(QuicOwnedPacketBuffer encrypted_buffer, size_t encrypted_buffer_len) { if (packet_.encrypted_buffer != nullptr) { @@ -802,9 +835,18 @@ bool QuicPacketCreator::SerializePacket(QuicOwnedPacketBuffer encrypted_buffer, QUICHE_DCHECK_GE(max_plaintext_size_, packet_size_) << ENDPOINT; // Use the packet_size_ instead of the buffer size to ensure smaller // packet sizes are properly used. - size_t length = - framer_->BuildDataPacket(header, queued_frames_, encrypted_buffer.buffer, - packet_size_, packet_.encryption_level); + + size_t length; + absl::optional length_with_chaos_protection = + MaybeBuildDataPacketWithChaosProtection(header, encrypted_buffer.buffer); + if (length_with_chaos_protection.has_value()) { + length = length_with_chaos_protection.value(); + } else { + length = framer_->BuildDataPacket(header, queued_frames_, + encrypted_buffer.buffer, packet_size_, + packet_.encryption_level); + } + if (length == 0) { QUIC_BUG(quic_bug_10752_16) << ENDPOINT << "Failed to serialize " @@ -937,7 +979,7 @@ QuicPacketCreator::SerializePathChallengeConnectivityProbingPacket( std::unique_ptr QuicPacketCreator::SerializePathResponseConnectivityProbingPacket( - const QuicCircularDeque& payloads, + const quiche::QuicheCircularDeque& payloads, const bool is_padded) { QUIC_BUG_IF(quic_bug_12398_13, !VersionHasIetfQuicFrames(framer_->transport_version())) @@ -1010,7 +1052,7 @@ size_t QuicPacketCreator::BuildPathResponsePacket( const QuicPacketHeader& header, char* buffer, size_t packet_length, - const QuicCircularDeque& payloads, + const quiche::QuicheCircularDeque& payloads, const bool is_padded, EncryptionLevel level) { if (payloads.empty()) { @@ -1412,8 +1454,14 @@ size_t QuicPacketCreator::ConsumeCryptoData(EncryptionLevel level, // The only pending data in the packet is non-retransmittable frames. I'm // assuming here that they won't occupy so much of the packet that a // CRYPTO frame won't fit. - QUIC_BUG(quic_bug_10752_26) - << ENDPOINT << "Failed to ConsumeCryptoData at level " << level; + const std::string error_message = absl::StrCat( + ENDPOINT, "Failed to ConsumeCryptoData at level ", level, + ", pending_frames: ", GetPendingFramesInfo(), + ", has_soft_max_packet_length: ", HasSoftMaxPacketLength(), + ", max_packet_length: ", max_packet_length_, ", transmission_type: ", + TransmissionTypeToString(next_transmission_type_), + ", packet_number: ", packet_number().ToString()); + QUIC_BUG(quic_bug_10752_26) << error_message; return 0; } total_bytes_consumed += frame.crypto_frame->data_length; @@ -1482,7 +1530,7 @@ bool QuicPacketCreator::FlushAckFrame(const QuicFrames& frames) { QUIC_BUG_IF(quic_bug_12398_18, GetQuicReloadableFlag(quic_single_ack_in_packet2) && !frames.empty() && has_ack()) - << ENDPOINT << "Trying to flush " << frames + << ENDPOINT << "Trying to flush " << quiche::PrintElements(frames) << " when there is ACK queued"; for (const auto& frame : frames) { QUICHE_DCHECK(frame.type == ACK_FRAME || frame.type == STOP_WAITING_FRAME) @@ -1556,14 +1604,14 @@ void QuicPacketCreator::SetTransmissionType(TransmissionType type) { next_transmission_type_ = type; } -MessageStatus QuicPacketCreator::AddMessageFrame(QuicMessageId message_id, - QuicMemSliceSpan message) { +MessageStatus QuicPacketCreator::AddMessageFrame( + QuicMessageId message_id, absl::Span message) { QUIC_BUG_IF(quic_bug_10752_33, !flusher_attached_) << ENDPOINT << "Packet flusher is not attached when " "generator tries to add message frame."; MaybeBundleAckOpportunistically(); - const QuicByteCount message_length = message.total_length(); + const QuicByteCount message_length = MemSliceSpanTotalSize(message); if (message_length > GetCurrentLargestMessagePayload()) { return MESSAGE_STATUS_TOO_LARGE; } @@ -1578,6 +1626,8 @@ MessageStatus QuicPacketCreator::AddMessageFrame(QuicMessageId message_id, delete frame; return MESSAGE_STATUS_INTERNAL_ERROR; } + QUICHE_DCHECK_EQ(MemSliceSpanTotalSize(message), + 0u); // Ensure the old slices are empty. return MESSAGE_STATUS_SUCCESS; } @@ -2084,10 +2134,14 @@ QuicPacketCreator::ScopedPeerAddressContext::ScopedPeerAddressContext( "initialized."; creator_->SetDefaultPeerAddress(address); if (update_connection_id_) { - QUICHE_DCHECK(address != old_peer_address_ || - ((client_connection_id == old_client_connection_id_) && - (server_connection_id == old_server_connection_id_))) - << ENDPOINT2; + // Flush current packet if connection ID length changes. + if (address == old_peer_address_ && + ((client_connection_id.length() != + old_client_connection_id_.length()) || + (server_connection_id.length() != + old_server_connection_id_.length()))) { + creator_->FlushCurrentPacket(); + } creator_->SetClientConnectionId(client_connection_id); creator_->SetServerConnectionId(server_connection_id); } diff --git a/gquiche/quic/core/quic_packet_creator.h b/gquiche/quic/core/quic_packet_creator.h index 8be7cd3c..cb5e0e0d 100644 --- a/gquiche/quic/core/quic_packet_creator.h +++ b/gquiche/quic/core/quic_packet_creator.h @@ -22,14 +22,15 @@ #include "absl/base/attributes.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "gquiche/quic/core/frames/quic_stream_frame.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_coalesced_packet.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_framer.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { namespace test { @@ -204,6 +205,10 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { // Returns true if there are frames pending to be serialized. bool HasPendingFrames() const; + // TODO(haoyuewang) Remove this debug utility. + // Returns the information of pending frames as a string. + std::string GetPendingFramesInfo() const; + // Returns true if there are retransmittable frames pending to be serialized. bool HasPendingRetransmittableFrames() const; @@ -258,7 +263,7 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { // |payloads| is cleared. std::unique_ptr SerializePathResponseConnectivityProbingPacket( - const QuicCircularDeque& payloads, + const quiche::QuicheCircularDeque& payloads, const bool is_padded); // Add PATH_RESPONSE to current packet, flush before or afterwards if needed. @@ -308,6 +313,11 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { void set_encryption_level(EncryptionLevel level); EncryptionLevel encryption_level() { return packet_.encryption_level; } + // Sets whether initial packets are protected with chaos. + void set_chaos_protection_enabled(bool chaos_protection_enabled) { + chaos_protection_enabled_ = chaos_protection_enabled; + } + // packet number of the last created packet, or 0 if no packets have been // created. QuicPacketNumber packet_number() const { return packet_.packet_number; } @@ -411,7 +421,7 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { // Tries to add a message frame containing |message| and returns the status. MessageStatus AddMessageFrame(QuicMessageId message_id, - QuicMemSliceSpan message); + absl::Span message); // Returns the largest payload that will fit into a single MESSAGE frame. QuicPacketLength GetCurrentLargestMessagePayload() const; @@ -465,7 +475,7 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { const QuicPacketHeader& header, char* buffer, size_t packet_length, - const QuicCircularDeque& payloads, + const quiche::QuicheCircularDeque& payloads, const bool is_padded, EncryptionLevel level); @@ -508,6 +518,13 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { QuicPacketCreator* creator_; // Unowned. }; + // Attempts to build a data packet with chaos protection. If this packet isn't + // supposed to be protected or if serialization fails then absl::nullopt is + // returned. Otherwise returns the serialized length. + absl::optional MaybeBuildDataPacketWithChaosProtection( + const QuicPacketHeader& header, + char* buffer); + // Creates a stream frame which fits into the current open packet. If // |data_size| is 0 and fin is true, the expected behavior is to consume // the fin. @@ -692,6 +709,9 @@ class QUIC_EXPORT_PRIVATE QuicPacketCreator { // accept. There is no limit for QUIC_CRYPTO connections, but QUIC+TLS // negotiates this during the handshake. QuicByteCount max_datagram_frame_size_; + + // Whether to attempt protecting initial packets with chaos. + bool chaos_protection_enabled_; }; } // namespace quic diff --git a/gquiche/quic/core/quic_packet_creator_test.cc b/gquiche/quic/core/quic_packet_creator_test.cc index dd8d9528..45842f28 100644 --- a/gquiche/quic/core/quic_packet_creator_test.cc +++ b/gquiche/quic/core/quic_packet_creator_test.cc @@ -35,13 +35,14 @@ #include "gquiche/quic/test_tools/simple_quic_framer.h" #include "gquiche/common/test_tools/quiche_test_utils.h" -using testing::_; -using testing::DoAll; -using testing::InSequence; -using testing::Invoke; -using testing::Return; -using testing::SaveArg; -using testing::StrictMock; +using ::testing::_; +using ::testing::AtLeast; +using ::testing::DoAll; +using ::testing::InSequence; +using ::testing::Invoke; +using ::testing::Return; +using ::testing::SaveArg; +using ::testing::StrictMock; namespace quic { namespace test { @@ -270,6 +271,8 @@ class QuicPacketCreatorTest : public QuicTestWithParam { n * 2; } + void TestChaosProtection(bool enabled); + static constexpr QuicStreamOffset kOffset = 0u; char buffer_[kMaxOutgoingPacketSize]; @@ -420,6 +423,8 @@ TEST_P(QuicPacketCreatorTest, ConsumeDataFinOnly) { EXPECT_EQ(0u, consumed); CheckStreamFrame(frame, stream_id, std::string(), 0u, true); EXPECT_TRUE(creator_.HasPendingFrames()); + EXPECT_TRUE(absl::StartsWith(creator_.GetPendingFramesInfo(), + "type { STREAM_FRAME }")); } TEST_P(QuicPacketCreatorTest, CreateAllFreeBytesForStreamFrames) { @@ -752,7 +757,7 @@ TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket1ResponseUnpadded) { }; // clang-format on std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); size_t length = creator_.BuildPathResponsePacket( header, buffer.get(), ABSL_ARRAYSIZE(packet), payloads, @@ -799,7 +804,7 @@ TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket1ResponsePadded) { }; // clang-format on std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); size_t length = creator_.BuildPathResponsePacket( header, buffer.get(), ABSL_ARRAYSIZE(packet), payloads, @@ -849,7 +854,7 @@ TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket3ResponsesUnpadded) { // clang-format on std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); payloads.push_back(payload1); payloads.push_back(payload2); @@ -903,7 +908,7 @@ TEST_P(QuicPacketCreatorTest, BuildPathResponsePacket3ResponsesPadded) { // clang-format on std::unique_ptr buffer(new char[kMaxOutgoingPacketSize]); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); payloads.push_back(payload1); payloads.push_back(payload2); @@ -988,7 +993,7 @@ TEST_P(QuicPacketCreatorTest, SerializePathResponseProbePacket1PayloadPadded) { creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); std::unique_ptr encrypted( @@ -1018,7 +1023,7 @@ TEST_P(QuicPacketCreatorTest, creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); std::unique_ptr encrypted( @@ -1048,7 +1053,7 @@ TEST_P(QuicPacketCreatorTest, SerializePathResponseProbePacket2PayloadsPadded) { creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); payloads.push_back(payload1); @@ -1081,7 +1086,7 @@ TEST_P(QuicPacketCreatorTest, creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); payloads.push_back(payload1); @@ -1114,7 +1119,7 @@ TEST_P(QuicPacketCreatorTest, SerializePathResponseProbePacket3PayloadsPadded) { creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); payloads.push_back(payload1); payloads.push_back(payload2); @@ -1150,7 +1155,7 @@ TEST_P(QuicPacketCreatorTest, creator_.set_encryption_level(ENCRYPTION_FORWARD_SECURE); - QuicCircularDeque payloads; + quiche::QuicheCircularDeque payloads; payloads.push_back(payload0); payloads.push_back(payload1); payloads.push_back(payload2); @@ -1312,7 +1317,7 @@ TEST_P(QuicPacketCreatorTest, SerializeFrameShortData) { if (!GetParam().version_serialization) { creator_.StopSendingVersion(); } - std::string data("a"); + std::string data("Hello World!"); if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { QuicStreamFrame stream_frame( QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), @@ -1339,15 +1344,51 @@ TEST_P(QuicPacketCreatorTest, SerializeFrameShortData) { } else { EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); } - if (client_framer_.version().HasHeaderProtection()) { - EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); - } EXPECT_CALL(framer_visitor_, OnPacketComplete()); } ProcessPacket(serialized); EXPECT_EQ(GetParam().version_serialization, header.version_flag); } +void QuicPacketCreatorTest::TestChaosProtection(bool enabled) { + if (!GetParam().version.UsesCryptoFrames()) { + return; + } + MockRandom mock_random(2); + QuicPacketCreatorPeer::SetRandom(&creator_, &mock_random); + creator_.set_chaos_protection_enabled(enabled); + std::string data("ChAoS_ThEoRy!"); + producer_.SaveCryptoData(ENCRYPTION_INITIAL, 0, data); + frames_.push_back( + QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data.length()))); + frames_.push_back(QuicFrame(QuicPaddingFrame(33))); + SerializedPacket serialized = SerializeAllFrames(frames_); + EXPECT_CALL(framer_visitor_, OnPacket()); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedPublicHeader(_)); + EXPECT_CALL(framer_visitor_, OnUnauthenticatedHeader(_)); + EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); + EXPECT_CALL(framer_visitor_, OnPacketHeader(_)); + if (enabled) { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)).Times(AtLeast(2)); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)).Times(AtLeast(2)); + EXPECT_CALL(framer_visitor_, OnPingFrame(_)).Times(AtLeast(1)); + } else { + EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)).Times(1); + EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)).Times(1); + EXPECT_CALL(framer_visitor_, OnPingFrame(_)).Times(0); + } + EXPECT_CALL(framer_visitor_, OnPacketComplete()); + ProcessPacket(serialized); +} + +TEST_P(QuicPacketCreatorTest, ChaosProtectionEnabled) { + TestChaosProtection(/*enabled=*/true); +} + +TEST_P(QuicPacketCreatorTest, ChaosProtectionDisabled) { + TestChaosProtection(/*enabled=*/false); +} + TEST_P(QuicPacketCreatorTest, ConsumeDataLargerThanOneStreamFrame) { if (!GetParam().version_serialization) { creator_.StopSendingVersion(); @@ -1728,25 +1769,24 @@ TEST_P(QuicPacketCreatorTest, AddMessageFrame) { .Times(3) .WillRepeatedly( Invoke(this, &QuicPacketCreatorTest::ClearSerializedPacketForTests)); - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); // Verify that there is enough room for the largest message payload. EXPECT_TRUE(creator_.HasRoomForMessageFrame( creator_.GetCurrentLargestMessagePayload())); - std::string message(creator_.GetCurrentLargestMessagePayload(), 'a'); + std::string large_message(creator_.GetCurrentLargestMessagePayload(), 'a'); QuicMessageFrame* message_frame = - new QuicMessageFrame(1, MakeSpan(&allocator_, message, &storage)); + new QuicMessageFrame(1, MemSliceFromString(large_message)); EXPECT_TRUE(creator_.AddFrame(QuicFrame(message_frame), NOT_RETRANSMISSION)); EXPECT_TRUE(creator_.HasPendingFrames()); creator_.FlushCurrentPacket(); QuicMessageFrame* frame2 = - new QuicMessageFrame(2, MakeSpan(&allocator_, "message", &storage)); + new QuicMessageFrame(2, MemSliceFromString("message")); EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame2), NOT_RETRANSMISSION)); EXPECT_TRUE(creator_.HasPendingFrames()); // Verify if a new frame is added, 1 byte message length will be added. EXPECT_EQ(1u, creator_.ExpansionOnNewFrame()); QuicMessageFrame* frame3 = - new QuicMessageFrame(3, MakeSpan(&allocator_, "message2", &storage)); + new QuicMessageFrame(3, MemSliceFromString("message2")); EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame3), NOT_RETRANSMISSION)); EXPECT_EQ(1u, creator_.ExpansionOnNewFrame()); creator_.FlushCurrentPacket(); @@ -1759,14 +1799,14 @@ TEST_P(QuicPacketCreatorTest, AddMessageFrame) { stream_id, &iov_, 1u, iov_.iov_len, 0u, 0u, false, false, NOT_RETRANSMISSION, &frame)); QuicMessageFrame* frame4 = - new QuicMessageFrame(4, MakeSpan(&allocator_, "message", &storage)); + new QuicMessageFrame(4, MemSliceFromString("message")); EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame4), NOT_RETRANSMISSION)); EXPECT_TRUE(creator_.HasPendingFrames()); // Verify there is not enough room for largest payload. EXPECT_FALSE(creator_.HasRoomForMessageFrame( creator_.GetCurrentLargestMessagePayload())); // Add largest message will causes the flush of the stream frame. - QuicMessageFrame frame5(5, MakeSpan(&allocator_, message, &storage)); + QuicMessageFrame frame5(5, MemSliceFromString(large_message)); EXPECT_FALSE(creator_.AddFrame(QuicFrame(&frame5), NOT_RETRANSMISSION)); EXPECT_FALSE(creator_.HasPendingFrames()); } @@ -1779,8 +1819,6 @@ TEST_P(QuicPacketCreatorTest, MessageFrameConsumption) { creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); } std::string message_data(kDefaultMaxPacketSize, 'a'); - absl::string_view message_buffer(message_data); - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); // Test all possible encryption levels of message frames. for (EncryptionLevel level : {ENCRYPTION_ZERO_RTT, ENCRYPTION_FORWARD_SECURE}) { @@ -1789,10 +1827,9 @@ TEST_P(QuicPacketCreatorTest, MessageFrameConsumption) { for (size_t message_size = 0; message_size <= creator_.GetCurrentLargestMessagePayload(); ++message_size) { - QuicMessageFrame* frame = new QuicMessageFrame( - 0, MakeSpan(&allocator_, - absl::string_view(message_buffer.data(), message_size), - &storage)); + QuicMessageFrame* frame = + new QuicMessageFrame(0, MemSliceFromString(absl::string_view( + message_data.data(), message_size))); EXPECT_TRUE(creator_.AddFrame(QuicFrame(frame), NOT_RETRANSMISSION)); EXPECT_TRUE(creator_.HasPendingFrames()); @@ -1826,7 +1863,7 @@ TEST_P(QuicPacketCreatorTest, GetGuaranteedLargestMessagePayload) { if (version.UsesTls()) { creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); } - QuicPacketLength expected_largest_payload = 1319; + QuicPacketLength expected_largest_payload = 1219; if (version.HasLongHeaderLengths()) { expected_largest_payload -= 2; } @@ -1877,7 +1914,7 @@ TEST_P(QuicPacketCreatorTest, GetCurrentLargestMessagePayload) { if (version.UsesTls()) { creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); } - QuicPacketLength expected_largest_payload = 1319; + QuicPacketLength expected_largest_payload = 1219; if (version.SendsVariableLengthPacketNumberInLongHeader()) { expected_largest_payload += 3; } @@ -1955,17 +1992,7 @@ TEST_P(QuicPacketCreatorTest, RetryToken) { creator_.SetRetryToken( std::string(retry_token_bytes, sizeof(retry_token_bytes))); - std::string data("a"); - if (!QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { - QuicStreamFrame stream_frame( - QuicUtils::GetCryptoStreamId(client_framer_.transport_version()), - /*fin=*/false, 0u, absl::string_view()); - frames_.push_back(QuicFrame(stream_frame)); - } else { - producer_.SaveCryptoData(ENCRYPTION_INITIAL, 0, data); - frames_.push_back( - QuicFrame(new QuicCryptoFrame(ENCRYPTION_INITIAL, 0, data.length()))); - } + frames_.push_back(QuicFrame(QuicPingFrame())); SerializedPacket serialized = SerializeAllFrames(frames_); QuicPacketHeader header; @@ -1977,11 +2004,7 @@ TEST_P(QuicPacketCreatorTest, RetryToken) { EXPECT_CALL(framer_visitor_, OnDecryptedPacket(_, _)); EXPECT_CALL(framer_visitor_, OnPacketHeader(_)) .WillOnce(DoAll(SaveArg<0>(&header), Return(true))); - if (QuicVersionUsesCryptoFrames(client_framer_.transport_version())) { - EXPECT_CALL(framer_visitor_, OnCryptoFrame(_)); - } else { - EXPECT_CALL(framer_visitor_, OnStreamFrame(_)); - } + EXPECT_CALL(framer_visitor_, OnPingFrame(_)); if (client_framer_.version().HasHeaderProtection()) { EXPECT_CALL(framer_visitor_, OnPaddingFrame(_)); } @@ -2425,12 +2448,13 @@ class MultiplePacketsTestPacketCreator : public QuicPacketCreator { } MessageStatus AddMessageFrame(QuicMessageId message_id, - QuicMemSliceSpan message) { + QuicMemSlice message) { if (!has_ack() && delegate_->ShouldGeneratePacket(NO_RETRANSMITTABLE_DATA, NOT_HANDSHAKE)) { EXPECT_CALL(*delegate_, MaybeBundleAckOpportunistically()).Times(1); } - return QuicPacketCreator::AddMessageFrame(message_id, message); + return QuicPacketCreator::AddMessageFrame(message_id, + absl::MakeSpan(&message, 1)); } size_t ConsumeCryptoData(EncryptionLevel level, @@ -3225,7 +3249,7 @@ TEST_F(QuicPacketCreatorMultiplePacketsTest, PacketTransmissionType) { // The first ConsumeData will fill the packet without flush. creator_.SetTransmissionType(LOSS_RETRANSMISSION); - size_t data_len = 1324; + size_t data_len = 1224; CreateData(data_len); QuicStreamId stream1_id = QuicUtils::GetFirstBidirectionalStreamId( framer_.transport_version(), Perspective::IS_CLIENT); @@ -3767,7 +3791,6 @@ TEST_F(QuicPacketCreatorMultiplePacketsTest, AddMessageFrame) { if (framer_.version().UsesTls()) { creator_.SetMaxDatagramFrameSize(kMaxAcceptedDatagramFrameSize); } - quic::QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); delegate_.SetCanWriteAnything(); EXPECT_CALL(delegate_, OnSerializedPacket(_)) .WillOnce( @@ -3777,30 +3800,23 @@ TEST_F(QuicPacketCreatorMultiplePacketsTest, AddMessageFrame) { creator_.ConsumeData(QuicUtils::GetFirstBidirectionalStreamId( framer_.transport_version(), Perspective::IS_CLIENT), &iov_, 1u, iov_.iov_len, 0, FIN); - EXPECT_EQ( - MESSAGE_STATUS_SUCCESS, - creator_.AddMessageFrame(1, MakeSpan(&allocator_, "message", &storage))); + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, + creator_.AddMessageFrame(1, MemSliceFromString("message"))); EXPECT_TRUE(creator_.HasPendingFrames()); EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); // Add a message which causes the flush of current packet. - EXPECT_EQ( - MESSAGE_STATUS_SUCCESS, - creator_.AddMessageFrame( - 2, - MakeSpan(&allocator_, - std::string(creator_.GetCurrentLargestMessagePayload(), 'a'), - &storage))); + EXPECT_EQ(MESSAGE_STATUS_SUCCESS, + creator_.AddMessageFrame( + 2, MemSliceFromString(std::string( + creator_.GetCurrentLargestMessagePayload(), 'a')))); EXPECT_TRUE(creator_.HasPendingRetransmittableFrames()); // Failed to send messages which cannot fit into one packet. - EXPECT_EQ( - MESSAGE_STATUS_TOO_LARGE, - creator_.AddMessageFrame( - 3, MakeSpan(&allocator_, - std::string( - creator_.GetCurrentLargestMessagePayload() + 10, 'a'), - &storage))); + EXPECT_EQ(MESSAGE_STATUS_TOO_LARGE, + creator_.AddMessageFrame( + 3, MemSliceFromString(std::string( + creator_.GetCurrentLargestMessagePayload() + 10, 'a')))); } TEST_F(QuicPacketCreatorMultiplePacketsTest, ConnectionId) { diff --git a/gquiche/quic/core/quic_packets.cc b/gquiche/quic/core/quic_packets.cc index 983ff7cd..05e58976 100644 --- a/gquiche/quic/core/quic_packets.cc +++ b/gquiche/quic/core/quic_packets.cc @@ -15,7 +15,6 @@ #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/quic_packets.h b/gquiche/quic/core/quic_packets.h index c121325a..89ff326c 100644 --- a/gquiche/quic/core/quic_packets.h +++ b/gquiche/quic/core/quic_packets.h @@ -455,6 +455,7 @@ struct QUIC_EXPORT_PRIVATE ReceivedPacketInfo { ParsedQuicVersion version; QuicConnectionId destination_connection_id; QuicConnectionId source_connection_id; + absl::optional retry_token; }; } // namespace quic diff --git a/gquiche/quic/core/quic_path_validator.cc b/gquiche/quic/core/quic_path_validator.cc index 7d329a8e..a19df765 100644 --- a/gquiche/quic/core/quic_path_validator.cc +++ b/gquiche/quic/core/quic_path_validator.cc @@ -10,10 +10,12 @@ namespace quic { -class RetryAlarmDelegate : public QuicAlarm::Delegate { +class RetryAlarmDelegate : public QuicAlarm::DelegateWithContext { public: - explicit RetryAlarmDelegate(QuicPathValidator* path_validator) - : path_validator_(path_validator) {} + explicit RetryAlarmDelegate(QuicPathValidator* path_validator, + QuicConnectionContext* context) + : QuicAlarm::DelegateWithContext(context), + path_validator_(path_validator) {} RetryAlarmDelegate(const RetryAlarmDelegate&) = delete; RetryAlarmDelegate& operator=(const RetryAlarmDelegate&) = delete; @@ -32,12 +34,12 @@ std::ostream& operator<<(std::ostream& os, QuicPathValidator::QuicPathValidator(QuicAlarmFactory* alarm_factory, QuicConnectionArena* arena, SendDelegate* send_delegate, - QuicRandom* random) + QuicRandom* random, + QuicConnectionContext* context) : send_delegate_(send_delegate), random_(random), - retry_timer_( - alarm_factory->CreateAlarm(arena->New(this), - arena)), + retry_timer_(alarm_factory->CreateAlarm( + arena->New(this, context), arena)), retry_count_(0u) {} void QuicPathValidator::OnPathResponse(const QuicPathFrameBuffer& probing_data, diff --git a/gquiche/quic/core/quic_path_validator.h b/gquiche/quic/core/quic_path_validator.h index 625014b3..a4c3e51c 100644 --- a/gquiche/quic/core/quic_path_validator.h +++ b/gquiche/quic/core/quic_path_validator.h @@ -7,11 +7,13 @@ #include +#include "absl/container/inlined_vector.h" #include "gquiche/quic/core/crypto/quic_random.h" #include "gquiche/quic/core/quic_alarm.h" #include "gquiche/quic/core/quic_alarm_factory.h" #include "gquiche/quic/core/quic_arena_scoped_ptr.h" #include "gquiche/quic/core/quic_clock.h" +#include "gquiche/quic/core/quic_connection_context.h" #include "gquiche/quic/core/quic_one_block_arena.h" #include "gquiche/quic/core/quic_packet_writer.h" #include "gquiche/quic/core/quic_types.h" @@ -106,10 +108,9 @@ class QUIC_EXPORT_PRIVATE QuicPathValidator { std::unique_ptr context) = 0; }; - QuicPathValidator(QuicAlarmFactory* alarm_factory, - QuicConnectionArena* arena, - SendDelegate* delegate, - QuicRandom* random); + QuicPathValidator(QuicAlarmFactory* alarm_factory, QuicConnectionArena* arena, + SendDelegate* delegate, QuicRandom* random, + QuicConnectionContext* context); // Send PATH_CHALLENGE and start the retry timer. void StartPathValidation(std::unique_ptr context, @@ -145,7 +146,7 @@ class QUIC_EXPORT_PRIVATE QuicPathValidator { void ResetPathValidation(); // Has at most 3 entries due to validation timeout. - QuicInlinedVector probing_data_; + absl::InlinedVector probing_data_; SendDelegate* send_delegate_; QuicRandom* random_; std::unique_ptr path_context_; diff --git a/gquiche/quic/core/quic_path_validator_test.cc b/gquiche/quic/core/quic_path_validator_test.cc index 678b21eb..65a335c4 100644 --- a/gquiche/quic/core/quic_path_validator_test.cc +++ b/gquiche/quic/core/quic_path_validator_test.cc @@ -48,11 +48,10 @@ class MockSendDelegate : public QuicPathValidator::SendDelegate { class QuicPathValidatorTest : public QuicTest { public: QuicPathValidatorTest() - : path_validator_(&alarm_factory_, &arena_, &send_delegate_, &random_), - context_(new MockQuicPathValidationContext(self_address_, - peer_address_, - effective_peer_address_, - &writer_)), + : path_validator_(&alarm_factory_, &arena_, &send_delegate_, &random_, + /*context=*/nullptr), + context_(new MockQuicPathValidationContext( + self_address_, peer_address_, effective_peer_address_, &writer_)), result_delegate_( new testing::StrictMock()) { clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); diff --git a/gquiche/quic/core/quic_protocol_flags_list.h b/gquiche/quic/core/quic_protocol_flags_list.h index d2a334ca..c443b559 100644 --- a/gquiche/quic/core/quic_protocol_flags_list.h +++ b/gquiche/quic/core/quic_protocol_flags_list.h @@ -44,6 +44,26 @@ QUIC_PROTOCOL_FLAG(int64_t, "Time period for which a given connection_id should live in " "the time-wait state.") +// This number is relatively conservative. For example, there are at most 1K +// queued stateless resets, which consume 1K * 21B = 21KB. +QUIC_PROTOCOL_FLAG( + uint64_t, quic_time_wait_list_max_pending_packets, 1024, + "Upper limit of pending packets in time wait list when writer is blocked.") + +// Stop sending a reset if the recorded number of addresses that server has +// recently sent stateless reset to exceeds this limit. +QUIC_PROTOCOL_FLAG(uint64_t, quic_max_recent_stateless_reset_addresses, 1024, + "Max number of recorded recent reset addresses.") + +// After this timeout, recent reset addresses will be cleared. +// FLAGS_quic_max_recent_stateless_reset_addresses * (1000ms / +// FLAGS_quic_recent_stateless_reset_addresses_lifetime_ms) is roughly the max +// reset per second. For example, 1024 * (1000ms / 1000ms) = 1K reset per +// second. +QUIC_PROTOCOL_FLAG( + uint64_t, quic_recent_stateless_reset_addresses_lifetime_ms, 1000, + "Max time that a client address lives in recent reset addresses set.") + QUIC_PROTOCOL_FLAG(double, quic_bbr_cwnd_gain, 2.0f, @@ -77,6 +97,11 @@ QUIC_PROTOCOL_FLAG( "Congestion window fraction that the pacing sender allows in bursts " "during pacing.") +QUIC_PROTOCOL_FLAG( + int32_t, quic_lumpy_pacing_min_bandwidth_kbps, 1200, + "The minimum estimated client bandwidth below which the pacing sender will " + "not allow bursts.") + QUIC_PROTOCOL_FLAG(int32_t, quic_max_pace_time_into_future_ms, 10, @@ -103,6 +128,10 @@ QUIC_PROTOCOL_FLAG(bool, true, "If true, use random greased settings and frames.") +QUIC_PROTOCOL_FLAG( + bool, quic_enable_chaos_protection, true, + "If true, use chaos protection to randomize client initials.") + QUIC_PROTOCOL_FLAG(int64_t, quic_max_tracked_packet_count, 10000, @@ -248,4 +277,13 @@ QUIC_PROTOCOL_FLAG( true, "If true, QUIC QPACK decoder includes 32-bytes overheader per entry while " "comparing request/response header size against its upper limit.") + +QUIC_PROTOCOL_FLAG( + bool, + quic_reject_retry_token_in_initial_packet, + false, + "If true, always reject retry_token received in INITIAL packets") + +QUIC_PROTOCOL_FLAG(bool, quic_use_lower_server_response_mtu_for_test, false, + "If true, cap server response packet size at 1250.") #endif diff --git a/gquiche/quic/core/quic_received_packet_manager.cc b/gquiche/quic/core/quic_received_packet_manager.cc index 06f6ddba..a3c3ac8d 100644 --- a/gquiche/quic/core/quic_received_packet_manager.cc +++ b/gquiche/quic/core/quic_received_packet_manager.cc @@ -205,11 +205,7 @@ QuicTime::Delta QuicReceivedPacketManager::GetMaxAckDelay( // before sending an ack. QuicTime::Delta ack_delay = std::min( local_max_ack_delay_, rtt_stats.min_rtt() * ack_decimation_delay_); - if (GetQuicReloadableFlag(quic_ack_delay_alarm_granularity)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_ack_delay_alarm_granularity); - ack_delay = std::max(ack_delay, kAlarmGranularity); - } - return ack_delay; + return std::max(ack_delay, kAlarmGranularity); } void QuicReceivedPacketManager::MaybeUpdateAckFrequency( diff --git a/gquiche/quic/core/quic_received_packet_manager_test.cc b/gquiche/quic/core/quic_received_packet_manager_test.cc index 93905f6b..800ad9fd 100644 --- a/gquiche/quic/core/quic_received_packet_manager_test.cc +++ b/gquiche/quic/core/quic_received_packet_manager_test.cc @@ -397,9 +397,6 @@ TEST_F(QuicReceivedPacketManagerTest, SendDelayedAckDecimation) { } TEST_F(QuicReceivedPacketManagerTest, SendDelayedAckDecimationMin1ms) { - if (!GetQuicReloadableFlag(quic_ack_delay_alarm_granularity)) { - return; - } EXPECT_FALSE(HasPendingAck()); // Seed the min_rtt with a kAlarmGranularity signal. rtt_stats_.UpdateRtt(kAlarmGranularity, QuicTime::Delta::Zero(), @@ -437,7 +434,6 @@ TEST_F(QuicReceivedPacketManagerTest, EXPECT_FALSE(HasPendingAck()); QuicConfig config; QuicTagVector connection_options; - connection_options.push_back(kACKD); // No limit on the number of packets received before sending an ack. connection_options.push_back(kAKDU); config.SetConnectionOptionsToSend(connection_options); diff --git a/gquiche/quic/core/quic_sent_packet_manager.cc b/gquiche/quic/core/quic_sent_packet_manager.cc index d16d0143..1bca8bd0 100644 --- a/gquiche/quic/core/quic_sent_packet_manager.cc +++ b/gquiche/quic/core/quic_sent_packet_manager.cc @@ -24,7 +24,7 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" +#include "gquiche/common/print_elements.h" namespace quic { @@ -117,7 +117,8 @@ QuicSentPacketManager::QuicSentPacketManager( use_standard_deviation_for_pto_(false), pto_multiplier_without_rtt_samples_(3), num_ptos_for_path_degrading_(0), - ignore_pings_(false) { + ignore_pings_(false), + ignore_ack_delay_(false) { SetSendAlgorithm(congestion_control_type); if (pto_enabled_) { QUIC_RELOADABLE_FLAG_COUNT_N(quic_default_on_pto, 1, 2); @@ -158,7 +159,7 @@ void QuicSentPacketManager::SetFromConfig(const QuicConfig& config) { } } if (config.HasClientSentConnectionOption(kMAD0, perspective)) { - rtt_stats_.set_ignore_max_ack_delay(true); + ignore_ack_delay_ = true; } if (config.HasClientSentConnectionOption(kMAD2, perspective)) { // Set the minimum to the alarm granularity. @@ -208,8 +209,12 @@ void QuicSentPacketManager::SetFromConfig(const QuicConfig& config) { QUIC_CODE_COUNT(two_aggressive_ptos); num_tlp_timeout_ptos_ = 2; } + if (GetQuicReloadableFlag(quic_deprecate_tlpr)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_deprecate_tlpr, 2, 2); + } if (config.HasClientSentConnectionOption(kPLE1, perspective) || - config.HasClientSentConnectionOption(kTLPR, perspective)) { + (config.HasClientSentConnectionOption(kTLPR, perspective) && + !GetQuicReloadableFlag(quic_deprecate_tlpr))) { first_pto_srtt_multiplier_ = 0.5; } else if (config.HasClientSentConnectionOption(kPLE2, perspective)) { first_pto_srtt_multiplier_ = 1.5; @@ -221,6 +226,9 @@ void QuicSentPacketManager::SetFromConfig(const QuicConfig& config) { use_standard_deviation_for_pto_ = true; rtt_stats_.EnableStandardDeviationCalculation(); } + if (config.HasClientRequestedIndependentOption(kPDP1, perspective)) { + num_ptos_for_path_degrading_ = 1; + } if (config.HasClientRequestedIndependentOption(kPDP2, perspective)) { num_ptos_for_path_degrading_ = 2; } @@ -256,18 +264,22 @@ void QuicSentPacketManager::SetFromConfig(const QuicConfig& config) { // Initial window. if (GetQuicReloadableFlag(quic_unified_iw_options)) { if (config.HasClientRequestedIndependentOption(kIW03, perspective)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_unified_iw_options, 1, 4); initial_congestion_window_ = 3; send_algorithm_->SetInitialCongestionWindowInPackets(3); } if (config.HasClientRequestedIndependentOption(kIW10, perspective)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_unified_iw_options, 2, 4); initial_congestion_window_ = 10; send_algorithm_->SetInitialCongestionWindowInPackets(10); } if (config.HasClientRequestedIndependentOption(kIW20, perspective)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_unified_iw_options, 3, 4); initial_congestion_window_ = 20; send_algorithm_->SetInitialCongestionWindowInPackets(20); } if (config.HasClientRequestedIndependentOption(kIW50, perspective)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_unified_iw_options, 4, 4); initial_congestion_window_ = 50; send_algorithm_->SetInitialCongestionWindowInPackets(50); } @@ -292,10 +304,15 @@ void QuicSentPacketManager::SetFromConfig(const QuicConfig& config) { if (config.HasClientSentConnectionOption(k1RTO, perspective)) { max_rto_packets_ = 1; } - if (config.HasClientSentConnectionOption(kTLPR, perspective)) { + if (GetQuicReloadableFlag(quic_deprecate_tlpr)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_deprecate_tlpr, 1, 2); + } + if (config.HasClientSentConnectionOption(kTLPR, perspective) && + !GetQuicReloadableFlag(quic_deprecate_tlpr)) { enable_half_rtt_tail_loss_probe_ = true; } - if (config.HasClientRequestedIndependentOption(kTLPR, perspective)) { + if (config.HasClientRequestedIndependentOption(kTLPR, perspective) && + !GetQuicReloadableFlag(quic_deprecate_tlpr)) { enable_half_rtt_tail_loss_probe_ = true; } if (config.HasClientSentConnectionOption(kNRTO, perspective)) { @@ -1526,8 +1543,18 @@ void QuicSentPacketManager::OnAckFrameStart(QuicPacketNumber largest_acked, QuicTime ack_receive_time) { QUICHE_DCHECK(packets_acked_.empty()); QUICHE_DCHECK_LE(largest_acked, unacked_packets_.largest_sent_packet()); - if (ack_delay_time > peer_max_ack_delay()) { - ack_delay_time = peer_max_ack_delay(); + if (GetQuicReloadableFlag(quic_ignore_peer_max_ack_delay_during_handshake) && + supports_multiple_packet_number_spaces() && !handshake_finished_) { + QUIC_RELOADABLE_FLAG_COUNT(quic_ignore_peer_max_ack_delay_during_handshake); + // Ignore peer_max_ack_delay and use received ack_delay during + // handshake. + } else { + if (ack_delay_time > peer_max_ack_delay()) { + ack_delay_time = peer_max_ack_delay(); + } + if (ignore_ack_delay_) { + ack_delay_time = QuicTime::Delta::Zero(); + } } rtt_updated_ = MaybeUpdateRTT(largest_acked, ack_delay_time, ack_receive_time); @@ -1601,7 +1628,7 @@ AckResult QuicSentPacketManager::OnAckFrameEnd( << acked_packet.packet_number << ", last_ack_frame_: " << last_ack_frame_ << ", least_unacked: " << unacked_packets_.GetLeastUnacked() - << ", packets_acked_: " << packets_acked_; + << ", packets_acked_: " << quiche::PrintElements(packets_acked_); } else { QUIC_PEER_BUG(quic_peer_bug_10750_6) << "Received " << ack_decrypted_level diff --git a/gquiche/quic/core/quic_sent_packet_manager.h b/gquiche/quic/core/quic_sent_packet_manager.h index 2854fc0a..aa056fba 100644 --- a/gquiche/quic/core/quic_sent_packet_manager.h +++ b/gquiche/quic/core/quic_sent_packet_manager.h @@ -19,7 +19,6 @@ #include "gquiche/quic/core/congestion_control/send_algorithm_interface.h" #include "gquiche/quic/core/congestion_control/uber_loss_algorithm.h" #include "gquiche/quic/core/proto/cached_network_parameters_proto.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_sustained_bandwidth_recorder.h" #include "gquiche/quic/core/quic_time.h" @@ -28,6 +27,7 @@ #include "gquiche/quic/core/quic_unacked_packet_map.h" #include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -467,6 +467,13 @@ class QUIC_EXPORT_PRIVATE QuicSentPacketManager { QuicTime GetEarliestPacketSentTimeForPto( PacketNumberSpace* packet_number_space) const; + void set_num_ptos_for_path_degrading(int num_ptos_for_path_degrading) { + num_ptos_for_path_degrading_ = num_ptos_for_path_degrading; + } + + // Sets the initial RTT of the connection. + void SetInitialRtt(QuicTime::Delta rtt); + private: friend class test::QuicConnectionPeer; friend class test::QuicSentPacketManagerPeer; @@ -550,9 +557,6 @@ class QUIC_EXPORT_PRIVATE QuicSentPacketManager { // this function. void RecordOneSpuriousRetransmission(const QuicTransmissionInfo& info); - // Sets the initial RTT of the connection. - void SetInitialRtt(QuicTime::Delta rtt); - // Called when handshake is confirmed to remove the retransmittable frames // from all packets of HANDSHAKE_DATA packet number space to ensure they don't // get retransmitted and will eventually be removed from unacked packets map. @@ -618,6 +622,7 @@ class QUIC_EXPORT_PRIVATE QuicSentPacketManager { // Maximum number of packets to send upon RTO. QuicPacketCount max_rto_packets_; // If true, send the TLP at 0.5 RTT. + // TODO(renjietang): remove it once quic_deprecate_tlpr flag is deprecated. bool enable_half_rtt_tail_loss_probe_; bool using_pacing_; // If true, use the new RTO with loss based CWND reduction instead of the send @@ -673,7 +678,8 @@ class QUIC_EXPORT_PRIVATE QuicSentPacketManager { // The history of outstanding max_ack_delays sent to peer. Outstanding means // a max_ack_delay is sent as part of the last acked AckFrequencyFrame or // an unacked AckFrequencyFrame after that. - QuicCircularDeque> + quiche::QuicheCircularDeque< + std::pair> in_use_sent_ack_delays_; // Latest received ack frame. @@ -744,6 +750,9 @@ class QUIC_EXPORT_PRIVATE QuicSentPacketManager { // If true, do not use PING only packets for RTT measurement or congestion // control. bool ignore_pings_; + + // Whether to ignore the ack_delay in received ACKs. + bool ignore_ack_delay_; }; } // namespace quic diff --git a/gquiche/quic/core/quic_sent_packet_manager_test.cc b/gquiche/quic/core/quic_sent_packet_manager_test.cc index 48ec1e11..2f19e2f2 100644 --- a/gquiche/quic/core/quic_sent_packet_manager_test.cc +++ b/gquiche/quic/core/quic_sent_packet_manager_test.cc @@ -41,8 +41,7 @@ const QuicStreamId kStreamId = 7; // Matcher to check that the packet number matches the second argument. MATCHER(PacketNumberEq, "") { - return ::testing::get<0>(arg).packet_number == - QuicPacketNumber(::testing::get<1>(arg)); + return std::get<0>(arg).packet_number == QuicPacketNumber(std::get<1>(arg)); } class MockDebugDelegate : public QuicSentPacketManager::DebugDelegate { @@ -1465,7 +1464,8 @@ TEST_F(QuicSentPacketManagerTest, GetTransmissionTimeTailLossProbe) { } TEST_F(QuicSentPacketManagerTest, TLPRWithPendingStreamData) { - if (GetQuicReloadableFlag(quic_default_on_pto)) { + if (GetQuicReloadableFlag(quic_default_on_pto) || + GetQuicReloadableFlag(quic_deprecate_tlpr)) { return; } QuicConfig config; @@ -1518,7 +1518,8 @@ TEST_F(QuicSentPacketManagerTest, TLPRWithPendingStreamData) { } TEST_F(QuicSentPacketManagerTest, TLPRWithoutPendingStreamData) { - if (GetQuicReloadableFlag(quic_default_on_pto)) { + if (GetQuicReloadableFlag(quic_default_on_pto) || + GetQuicReloadableFlag(quic_deprecate_tlpr)) { return; } QuicConfig config; @@ -2125,7 +2126,8 @@ TEST_F(QuicSentPacketManagerTest, Negotiate1TLPFromOptionsAtClient) { } TEST_F(QuicSentPacketManagerTest, NegotiateTLPRttFromOptionsAtServer) { - if (GetQuicReloadableFlag(quic_default_on_pto)) { + if (GetQuicReloadableFlag(quic_default_on_pto) || + GetQuicReloadableFlag(quic_deprecate_tlpr)) { return; } QuicConfig config; @@ -2141,7 +2143,8 @@ TEST_F(QuicSentPacketManagerTest, NegotiateTLPRttFromOptionsAtServer) { } TEST_F(QuicSentPacketManagerTest, NegotiateTLPRttFromOptionsAtClient) { - if (GetQuicReloadableFlag(quic_default_on_pto)) { + if (GetQuicReloadableFlag(quic_default_on_pto) || + GetQuicReloadableFlag(quic_deprecate_tlpr)) { return; } QuicConfig client_config; @@ -4038,7 +4041,7 @@ TEST_F(QuicSentPacketManagerTest, GetPathDegradingDelay) { EXPECT_EQ(expected_delay, manager_.GetPathDegradingDelay()); } -TEST_F(QuicSentPacketManagerTest, GetPathDegradingDelayUsingPTO) { +TEST_F(QuicSentPacketManagerTest, GetPathDegradingDelayUsing2PTO) { QuicConfig client_config; QuicTagVector options; options.push_back(k1PTO); @@ -4055,6 +4058,23 @@ TEST_F(QuicSentPacketManagerTest, GetPathDegradingDelayUsingPTO) { EXPECT_EQ(expected_delay, manager_.GetPathDegradingDelay()); } +TEST_F(QuicSentPacketManagerTest, GetPathDegradingDelayUsing1PTO) { + QuicConfig client_config; + QuicTagVector options; + options.push_back(k1PTO); + QuicTagVector client_options; + client_options.push_back(kPDP1); + QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); + client_config.SetConnectionOptionsToSend(options); + client_config.SetClientConnectionOptions(client_options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + manager_.SetFromConfig(client_config); + EXPECT_TRUE(manager_.pto_enabled()); + QuicTime::Delta expected_delay = 1 * manager_.GetPtoDelay(); + EXPECT_EQ(expected_delay, manager_.GetPathDegradingDelay()); +} + TEST_F(QuicSentPacketManagerTest, ClientsIgnorePings) { QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); QuicConfig client_config; @@ -4352,6 +4372,9 @@ TEST_F(QuicSentPacketManagerTest, } TEST_F(QuicSentPacketManagerTest, ClientOnlyTLPRServer) { + if (GetQuicReloadableFlag(quic_deprecate_tlpr)) { + return; + } QuicConfig config; QuicTagVector options; @@ -4366,6 +4389,9 @@ TEST_F(QuicSentPacketManagerTest, ClientOnlyTLPRServer) { } TEST_F(QuicSentPacketManagerTest, ClientOnlyTLPR) { + if (GetQuicReloadableFlag(quic_deprecate_tlpr)) { + return; + } QuicSentPacketManagerPeer::SetPerspective(&manager_, Perspective::IS_CLIENT); QuicConfig config; QuicTagVector options; @@ -4380,6 +4406,9 @@ TEST_F(QuicSentPacketManagerTest, ClientOnlyTLPR) { } TEST_F(QuicSentPacketManagerTest, PtoWithTlpr) { + if (GetQuicReloadableFlag(quic_deprecate_tlpr)) { + return; + } QuicConfig config; QuicTagVector options; @@ -4637,8 +4666,7 @@ TEST_F(QuicSentPacketManagerTest, ClearDataInMessageFrameAfterPacketSent) { QuicMessageFrame* message_frame = nullptr; { QuicMemSlice slice(MakeUniqueBuffer(&allocator_, 1024), 1024); - message_frame = - new QuicMessageFrame(/*message_id=*/1, QuicMemSliceSpan(&slice)); + message_frame = new QuicMessageFrame(/*message_id=*/1, std::move(slice)); EXPECT_FALSE(message_frame->message_data.empty()); EXPECT_EQ(message_frame->message_length, 1024); @@ -4684,6 +4712,125 @@ TEST_F(QuicSentPacketManagerTest, BuildAckFrequencyFrame) { EXPECT_EQ(frame.packet_tolerance, 10u); } +TEST_F(QuicSentPacketManagerTest, SmoothedRttIgnoreAckDelay) { + QuicConfig config; + QuicTagVector options; + options.push_back(kMAD0); + QuicConfigPeer::SetReceivedConnectionOptions(&config, options); + EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); + EXPECT_CALL(*network_change_visitor_, OnCongestionChange()); + EXPECT_CALL(*send_algorithm_, CanSend(_)).WillRepeatedly(Return(true)); + EXPECT_CALL(*send_algorithm_, GetCongestionWindow()) + .WillRepeatedly(Return(10 * kDefaultTCPMSS)); + manager_.SetFromConfig(config); + + SendDataPacket(1); + // Ack 1. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(300)); + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), + QuicTime::Delta::FromMilliseconds(100), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL)); + // Verify that ack_delay is ignored in the first measurement. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->smoothed_rtt()); + + SendDataPacket(2); + // Ack 2. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(300)); + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), + QuicTime::Delta::FromMilliseconds(100), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_INITIAL)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->smoothed_rtt()); + + SendDataPacket(3); + // Ack 3. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(300)); + ExpectAck(3); + manager_.OnAckFrameStart(QuicPacketNumber(3), + QuicTime::Delta::FromMilliseconds(50), clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(3), QuicPacketNumber(4)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(3), + ENCRYPTION_INITIAL)); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(300), + manager_.GetRttStats()->smoothed_rtt()); + + SendDataPacket(4); + // Ack 4. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(200)); + ExpectAck(4); + manager_.OnAckFrameStart(QuicPacketNumber(4), + QuicTime::Delta::FromMilliseconds(300), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(4), QuicPacketNumber(5)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(4), + ENCRYPTION_INITIAL)); + // Verify that large erroneous ack_delay does not change Smoothed RTT. + EXPECT_EQ(QuicTime::Delta::FromMilliseconds(200), + manager_.GetRttStats()->latest_rtt()); + EXPECT_EQ(QuicTime::Delta::FromMicroseconds(287500), + manager_.GetRttStats()->smoothed_rtt()); +} + +TEST_F(QuicSentPacketManagerTest, IgnorePeerMaxAckDelayDuringHandshake) { + manager_.EnableMultiplePacketNumberSpacesSupport(); + // 100ms RTT. + const QuicTime::Delta kTestRTT = QuicTime::Delta::FromMilliseconds(100); + + // Server sends INITIAL 1 and HANDSHAKE 2. + SendDataPacket(1, ENCRYPTION_INITIAL); + SendDataPacket(2, ENCRYPTION_HANDSHAKE); + + // Receive client ACK for INITIAL 1 after one RTT. + clock_.AdvanceTime(kTestRTT); + ExpectAck(1); + manager_.OnAckFrameStart(QuicPacketNumber(1), QuicTime::Delta::Infinite(), + clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(1), QuicPacketNumber(2)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(1), + ENCRYPTION_INITIAL)); + EXPECT_EQ(kTestRTT, manager_.GetRttStats()->latest_rtt()); + + // Assume the cert verification on client takes 50ms, such that the HANDSHAKE + // packet is queued for 50ms. + const QuicTime::Delta queuing_delay = QuicTime::Delta::FromMilliseconds(50); + clock_.AdvanceTime(queuing_delay); + // Ack 2. + ExpectAck(2); + manager_.OnAckFrameStart(QuicPacketNumber(2), queuing_delay, clock_.Now()); + manager_.OnAckRange(QuicPacketNumber(2), QuicPacketNumber(3)); + EXPECT_EQ(PACKETS_NEWLY_ACKED, + manager_.OnAckFrameEnd(clock_.Now(), QuicPacketNumber(2), + ENCRYPTION_HANDSHAKE)); + if (GetQuicReloadableFlag(quic_ignore_peer_max_ack_delay_during_handshake)) { + EXPECT_EQ(kTestRTT, manager_.GetRttStats()->latest_rtt()); + } else { + // Verify the ack_delay gets capped by the peer_max_ack_delay. + EXPECT_EQ(kTestRTT + queuing_delay - + QuicTime::Delta::FromMilliseconds(kDefaultDelayedAckTimeMs), + manager_.GetRttStats()->latest_rtt()); + } +} + TEST_F(QuicSentPacketManagerTest, BuildAckFrequencyFrameWithSRTT) { SetQuicReloadableFlag(quic_can_send_ack_frequency, true); EXPECT_CALL(*send_algorithm_, SetFromConfig(_, _)); diff --git a/gquiche/quic/core/quic_server_id.cc b/gquiche/quic/core/quic_server_id.cc index fd77efc2..34707d32 100644 --- a/gquiche/quic/core/quic_server_id.cc +++ b/gquiche/quic/core/quic_server_id.cc @@ -7,8 +7,6 @@ #include #include -#include "gquiche/quic/platform/api/quic_estimate_memory_usage.h" - namespace quic { QuicServerId::QuicServerId() : QuicServerId("", 0, false) {} @@ -33,8 +31,8 @@ bool QuicServerId::operator==(const QuicServerId& other) const { host_ == other.host_ && port_ == other.port_; } -size_t QuicServerId::EstimateMemoryUsage() const { - return QuicEstimateMemoryUsage(host_); +bool QuicServerId::operator!=(const QuicServerId& other) const { + return !(*this == other); } } // namespace quic diff --git a/gquiche/quic/core/quic_server_id.h b/gquiche/quic/core/quic_server_id.h index 660d1342..a1a6aa42 100644 --- a/gquiche/quic/core/quic_server_id.h +++ b/gquiche/quic/core/quic_server_id.h @@ -8,6 +8,7 @@ #include #include +#include "absl/hash/hash.h" #include "gquiche/quic/platform/api/quic_export.h" namespace quic { @@ -23,24 +24,32 @@ class QUIC_EXPORT_PRIVATE QuicServerId { bool privacy_mode_enabled); ~QuicServerId(); - // Needed to be an element of std::set. + // Needed to be an element of an ordered container. bool operator<(const QuicServerId& other) const; bool operator==(const QuicServerId& other) const; + bool operator!=(const QuicServerId& other) const; + const std::string& host() const { return host_; } uint16_t port() const { return port_; } bool privacy_mode_enabled() const { return privacy_mode_enabled_; } - size_t EstimateMemoryUsage() const; - private: std::string host_; uint16_t port_; bool privacy_mode_enabled_; }; +class QUIC_EXPORT_PRIVATE QuicServerIdHash { + public: + size_t operator()(const quic::QuicServerId& server_id) const noexcept { + return absl::HashOf(server_id.host(), server_id.port(), + server_id.privacy_mode_enabled()); + } +}; + } // namespace quic #endif // QUICHE_QUIC_CORE_QUIC_SERVER_ID_H_ diff --git a/gquiche/quic/core/quic_server_id_test.cc b/gquiche/quic/core/quic_server_id_test.cc index 3b358c7b..7a6c6fa6 100644 --- a/gquiche/quic/core/quic_server_id_test.cc +++ b/gquiche/quic/core/quic_server_id_test.cc @@ -6,7 +6,6 @@ #include -#include "gquiche/quic/platform/api/quic_estimate_memory_usage.h" #include "gquiche/quic/platform/api/quic_test.h" namespace quic { @@ -87,6 +86,10 @@ TEST_F(QuicServerIdTest, Equals) { QuicServerId b_10_https_right_private("b.com", 10, right_privacy); QuicServerId b_11_https_right_private("b.com", 11, right_privacy); + EXPECT_NE(a_10_https_right_private, a_11_https_right_private); + EXPECT_NE(a_10_https_right_private, b_10_https_right_private); + EXPECT_NE(a_10_https_right_private, b_11_https_right_private); + QuicServerId new_a_10_https_left_private("a.com", 10, left_privacy); QuicServerId new_a_11_https_left_private("a.com", 11, left_privacy); QuicServerId new_b_10_https_left_private("b.com", 10, left_privacy); @@ -107,21 +110,13 @@ TEST_F(QuicServerIdTest, Equals) { QuicServerId new_a_10_https_left_private("a.com", 10, false); - EXPECT_FALSE(new_a_10_https_left_private == a_11_https_right_private); - EXPECT_FALSE(new_a_10_https_left_private == b_10_https_right_private); - EXPECT_FALSE(new_a_10_https_left_private == b_11_https_right_private); + EXPECT_NE(new_a_10_https_left_private, a_11_https_right_private); + EXPECT_NE(new_a_10_https_left_private, b_10_https_right_private); + EXPECT_NE(new_a_10_https_left_private, b_11_https_right_private); } QuicServerId a_10_https_private("a.com", 10, true); QuicServerId new_a_10_https_no_private("a.com", 10, false); - EXPECT_FALSE(new_a_10_https_no_private == a_10_https_private); -} - -TEST_F(QuicServerIdTest, EstimateMemoryUsage) { - std::string host = "this is a rather very quite long hostname"; - uint16_t port = 10; - bool privacy_mode_enabled = true; - QuicServerId server_id(host, port, privacy_mode_enabled); - EXPECT_EQ(QuicEstimateMemoryUsage(host), QuicEstimateMemoryUsage(server_id)); + EXPECT_NE(new_a_10_https_no_private, a_10_https_private); } } // namespace diff --git a/gquiche/quic/core/quic_session.cc b/gquiche/quic/core/quic_session.cc index 56d74906..dc2bb7c4 100644 --- a/gquiche/quic/core/quic_session.cc +++ b/gquiche/quic/core/quic_session.cc @@ -12,7 +12,9 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "gquiche/quic/core/frames/quic_window_update_frame.h" #include "gquiche/quic/core/quic_connection.h" +#include "gquiche/quic/core/quic_connection_context.h" #include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_flow_controller.h" #include "gquiche/quic/core/quic_types.h" @@ -22,10 +24,9 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_server_stats.h" #include "gquiche/quic/platform/api/quic_stack_trace.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" using spdy::SpdyPriority; @@ -41,6 +42,12 @@ class ClosedStreamsCleanUpDelegate : public QuicAlarm::Delegate { ClosedStreamsCleanUpDelegate& operator=(const ClosedStreamsCleanUpDelegate&) = delete; + QuicConnectionContext* GetConnectionContext() override { + return (session_->connection() == nullptr) + ? nullptr + : session_->connection()->context(); + } + void OnAlarm() override { session_->CleanUpClosedStreams(); } private: @@ -53,22 +60,14 @@ class ClosedStreamsCleanUpDelegate : public QuicAlarm::Delegate { (perspective() == Perspective::IS_SERVER ? "Server: " : "Client: ") QuicSession::QuicSession( - QuicConnection* connection, - Visitor* owner, - const QuicConfig& config, + QuicConnection* connection, Visitor* owner, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, QuicStreamCount num_expected_unidirectional_static_streams) - : QuicSession(connection, - owner, - config, - supported_versions, - num_expected_unidirectional_static_streams, - nullptr) {} + : QuicSession(connection, owner, config, supported_versions, + num_expected_unidirectional_static_streams, nullptr) {} QuicSession::QuicSession( - QuicConnection* connection, - Visitor* owner, - const QuicConfig& config, + QuicConnection* connection, Visitor* owner, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, QuicStreamCount num_expected_unidirectional_static_streams, std::unique_ptr datagram_observer) @@ -77,14 +76,10 @@ QuicSession::QuicSession( visitor_(owner), write_blocked_streams_(connection->transport_version()), config_(config), - stream_id_manager_(perspective(), - connection->transport_version(), + stream_id_manager_(perspective(), connection->transport_version(), kDefaultMaxStreamsPerConnection, config_.GetMaxBidirectionalStreamsToSend()), - ietf_streamid_manager_(perspective(), - connection->version(), - this, - 0, + ietf_streamid_manager_(perspective(), connection->version(), this, 0, num_expected_unidirectional_static_streams, config_.GetMaxBidirectionalStreamsToSend(), config_.GetMaxUnidirectionalStreamsToSend() + @@ -94,15 +89,13 @@ QuicSession::QuicSession( num_static_streams_(0), num_zombie_streams_(0), flow_controller_( - this, - QuicUtils::GetInvalidStreamId(connection->transport_version()), + this, QuicUtils::GetInvalidStreamId(connection->transport_version()), /*is_connection_flow_controller*/ true, connection->version().AllowsLowFlowControlLimits() ? 0 : kMinimumFlowControlSendWindow, config_.GetInitialSessionFlowControlWindowToSend(), - kSessionReceiveWindowLimit, - perspective() == Perspective::IS_SERVER, + kSessionReceiveWindowLimit, perspective() == Perspective::IS_SERVER, nullptr), currently_writing_stream_id_(0), transport_goaway_sent_(false), @@ -135,11 +128,15 @@ void QuicSession::Initialize() { connection_->SetDataProducer(this); connection_->SetUnackedMapInitialCapacity(); connection_->SetFromConfig(config_); - if (perspective_ == Perspective::IS_CLIENT && - config_.HasClientRequestedIndependentOption(kAFFE, perspective_) && - version().HasIetfQuicFrames()) { - connection_->set_can_receive_ack_frequency_frame(); - config_.SetMinAckDelayMs(kDefaultMinAckDelayTimeMs); + if (perspective_ == Perspective::IS_CLIENT) { + if (config_.HasClientRequestedIndependentOption(kAFFE, perspective_) && + version().HasIetfQuicFrames()) { + connection_->set_can_receive_ack_frequency_frame(); + config_.SetMinAckDelayMs(kDefaultMinAckDelayTimeMs); + } + if (config_.HasClientRequestedIndependentOption(kNBPE, perspective_)) { + permutes_tls_extensions_ = false; + } } connection_->CreateConnectionIdManager(); @@ -162,9 +159,14 @@ void QuicSession::Initialize() { GetMutableCryptoStream()->id()); } -QuicSession::~QuicSession() {} +QuicSession::~QuicSession() { + if (closed_streams_clean_up_alarm_ != nullptr) { + closed_streams_clean_up_alarm_->PermanentCancel(); + } +} -void QuicSession::PendingStreamOnStreamFrame(const QuicStreamFrame& frame) { +PendingStream* QuicSession::PendingStreamOnStreamFrame( + const QuicStreamFrame& frame) { QUICHE_DCHECK(VersionUsesHttp3(transport_version())); QuicStreamId stream_id = frame.stream_id; @@ -175,27 +177,63 @@ void QuicSession::PendingStreamOnStreamFrame(const QuicStreamFrame& frame) { QuicStreamOffset final_byte_offset = frame.offset + frame.data_length; OnFinalByteOffsetReceived(stream_id, final_byte_offset); } - return; + return nullptr; } pending->OnStreamFrame(frame); if (!connection()->connected()) { - return; + return nullptr; } + return pending; +} + +void QuicSession::MaybeProcessPendingStream(PendingStream* pending) { + QUICHE_DCHECK(pending != nullptr); + QuicStreamId stream_id = pending->id(); + absl::optional stop_sending_error_code = + pending->GetStopSendingErrorCode(); QuicStream* stream = ProcessPendingStream(pending); if (stream != nullptr) { // The pending stream should now be in the scope of normal streams. QUICHE_DCHECK(IsClosedStream(stream_id) || IsOpenStream(stream_id)) << "Stream " << stream_id << " not created"; pending_stream_map_.erase(stream_id); + if (stop_sending_error_code) { + stream->OnStopSending(*stop_sending_error_code); + if (!connection()->connected()) { + return; + } + } stream->OnStreamCreatedFromPendingStream(); return; } + // At this point, none of the bytes has been successfully consumed by the + // application layer. We should close the pending stream even if it is + // bidirectionl as no application will be able to write in a bidirectional + // stream with zero byte as input. if (pending->sequencer()->IsClosed()) { ClosePendingStream(stream_id); } } +void QuicSession::PendingStreamOnWindowUpdateFrame( + const QuicWindowUpdateFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + PendingStream* pending = GetOrCreatePendingStream(frame.stream_id); + if (pending) { + pending->OnWindowUpdateFrame(frame); + } +} + +void QuicSession::PendingStreamOnStopSendingFrame( + const QuicStopSendingFrame& frame) { + QUICHE_DCHECK(VersionUsesHttp3(transport_version())); + PendingStream* pending = GetOrCreatePendingStream(frame.stream_id); + if (pending) { + pending->OnStopSending(frame.error()); + } +} + void QuicSession::OnStreamFrame(const QuicStreamFrame& frame) { QuicStreamId stream_id = frame.stream_id; if (stream_id == QuicUtils::GetInvalidStreamId(transport_version())) { @@ -205,12 +243,11 @@ void QuicSession::OnStreamFrame(const QuicStreamFrame& frame) { return; } - if (UsesPendingStreams() && - QuicUtils::GetStreamType(stream_id, perspective(), - IsIncomingStream(stream_id), - version()) == READ_UNIDIRECTIONAL && - stream_map_.find(stream_id) == stream_map_.end()) { - PendingStreamOnStreamFrame(frame); + if (ShouldProcessFrameByPendingStream(STREAM_FRAME, stream_id)) { + PendingStream* pending = PendingStreamOnStreamFrame(frame); + if (pending != nullptr && ShouldProcessPendingStreamImmediately()) { + MaybeProcessPendingStream(pending); + } return; } @@ -269,6 +306,10 @@ void QuicSession::OnStopSendingFrame(const QuicStopSendingFrame& frame) { if (visitor_) { visitor_->OnStopSendingReceived(frame); } + if (ShouldProcessFrameByPendingStream(STOP_SENDING_FRAME, stream_id)) { + PendingStreamOnStopSendingFrame(frame); + return; + } QuicStream* stream = GetOrCreateStream(stream_id); if (!stream) { @@ -276,7 +317,7 @@ void QuicSession::OnStopSendingFrame(const QuicStopSendingFrame& frame) { return; } - stream->OnStopSending(frame.error_code); + stream->OnStopSending(frame.error()); } void QuicSession::OnPacketDecrypted(EncryptionLevel level) { @@ -316,11 +357,10 @@ void QuicSession::PendingStreamOnRstStream(const QuicRstStreamFrame& frame) { } pending->OnRstStreamFrame(frame); - // Pending stream is currently read only. We can safely close the stream. - QUICHE_DCHECK_EQ( - READ_UNIDIRECTIONAL, - QuicUtils::GetStreamType(pending->id(), perspective(), - /*peer_initiated = */ true, version())); + // At this point, none of the bytes has been consumed by the application + // layer. It is safe to close the pending stream even if it is bidirectionl as + // no application will be able to write in a bidirectional stream with zero + // byte as input. ClosePendingStream(stream_id); } @@ -347,11 +387,7 @@ void QuicSession::OnRstStream(const QuicRstStreamFrame& frame) { visitor_->OnRstStreamReceived(frame); } - if (UsesPendingStreams() && - QuicUtils::GetStreamType(stream_id, perspective(), - IsIncomingStream(stream_id), - version()) == READ_UNIDIRECTIONAL && - stream_map_.find(stream_id) == stream_map_.end()) { + if (ShouldProcessFrameByPendingStream(RST_STREAM_FRAME, stream_id)) { PendingStreamOnRstStream(frame); return; } @@ -435,7 +471,7 @@ void QuicSession::OnConnectionClosed(const QuicConnectionCloseFrame& frame, closed_streams_clean_up_alarm_->Cancel(); if (visitor_) { - visitor_->OnConnectionClosed(connection_->connection_id(), + visitor_->OnConnectionClosed(connection_->GetOneActiveServerConnectionId(), frame.quic_error_code, frame.error_details, source); } @@ -474,9 +510,7 @@ void QuicSession::OnPathDegrading() {} void QuicSession::OnForwardProgressMadeAfterPathDegrading() {} -bool QuicSession::AllowSelfAddressChange() const { - return false; -} +bool QuicSession::AllowSelfAddressChange() const { return false; } void QuicSession::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { // Stream may be closed by the time we receive a WINDOW_UPDATE, so we can't @@ -504,6 +538,11 @@ void QuicSession::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { return; } + if (ShouldProcessFrameByPendingStream(WINDOW_UPDATE_FRAME, stream_id)) { + PendingStreamOnWindowUpdateFrame(frame); + return; + } + QuicStream* stream = GetOrCreateStream(stream_id); if (stream != nullptr) { stream->OnWindowUpdateFrame(frame); @@ -563,17 +602,14 @@ bool QuicSession::CheckStreamWriteBlocked(QuicStream* stream) const { } void QuicSession::OnCanWrite() { - if (connection_->donot_write_mid_packet_processing()) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_donot_write_mid_packet_processing, 1, 3); - if (connection_->framer().is_processing_packet()) { - // Do not write data in the middle of packet processing because rest - // frames in the packet may change the data to write. For example, lost - // data could be acknowledged. Also, connection is going to emit - // OnCanWrite signal post packet processing. - QUIC_BUG(session_write_mid_packet_processing) - << ENDPOINT << "Try to write mid packet processing."; - return; - } + if (connection_->framer().is_processing_packet()) { + // Do not write data in the middle of packet processing because rest + // frames in the packet may change the data to write. For example, lost + // data could be acknowledged. Also, connection is going to emit + // OnCanWrite signal post packet processing. + QUIC_BUG(session_write_mid_packet_processing) + << ENDPOINT << "Try to write mid packet processing."; + return; } if (!RetransmitLostData()) { // Cannot finish retransmitting lost data, connection is write blocked. @@ -686,11 +722,8 @@ bool QuicSession::WillingAndAbleToWrite() const { if (HasPendingHandshake()) { return true; } - if (GetQuicReloadableFlag(quic_fix_willing_and_able_to_write2)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_fix_willing_and_able_to_write2); - if (!IsEncryptionEstablished()) { - return false; - } + if (!IsEncryptionEstablished()) { + return false; } } if (control_frame_manager_.WillingToWrite() || @@ -742,8 +775,8 @@ bool QuicSession::HasPendingHandshake() const { return GetCryptoStream()->HasPendingCryptoRetransmission() || GetCryptoStream()->HasBufferedCryptoFrames(); } - return QuicContainsKey(streams_with_pending_retransmission_, - QuicUtils::GetCryptoStreamId(transport_version())) || + return streams_with_pending_retransmission_.contains( + QuicUtils::GetCryptoStreamId(transport_version())) || write_blocked_streams_.IsStreamBlocked( QuicUtils::GetCryptoStreamId(transport_version())); } @@ -751,19 +784,17 @@ bool QuicSession::HasPendingHandshake() const { void QuicSession::ProcessUdpPacket(const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, const QuicReceivedPacket& packet) { + QuicConnectionContextSwitcher cs(connection_->context()); connection_->ProcessUdpPacket(self_address, peer_address, packet); } -QuicConsumedData QuicSession::WritevData( - QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, - absl::optional level) { +QuicConsumedData QuicSession::WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, + TransmissionType type, + EncryptionLevel level) { QUICHE_DCHECK(connection_->connected()) << ENDPOINT << "Try to write stream data when connection is closed."; - QUICHE_DCHECK(!use_write_or_buffer_data_at_level_ || level.has_value()); if (!IsEncryptionEstablished() && !QuicUtils::IsCryptoStreamId(transport_version(), id)) { // Do not let streams write without encryption. The calling stream will end @@ -771,26 +802,31 @@ QuicConsumedData QuicSession::WritevData( if (was_zero_rtt_rejected_ && !OneRttKeysAvailable()) { QUICHE_DCHECK(version().UsesTls() && perspective() == Perspective::IS_CLIENT); - QUIC_BUG_IF(quic_bug_12435_3, type == NOT_RETRANSMISSION) - << ENDPOINT << "Try to send new data on stream " << id - << "before 1-RTT keys are available while 0-RTT is rejected."; + QUIC_DLOG(INFO) << ENDPOINT + << "Suppress the write while 0-RTT gets rejected and " + "1-RTT keys are not available. Version: " + << ParsedQuicVersionToString(version()); + } else if (version().UsesTls() || perspective() == Perspective::IS_SERVER) { + QUIC_BUG(quic_bug_10866_2) + << ENDPOINT << "Try to send data of stream " << id + << " before encryption is established. Version: " + << ParsedQuicVersionToString(version()); } else { - QUIC_BUG(quic_bug_10866_2) << ENDPOINT << "Try to send data of stream " - << id << " before encryption is established."; + // In QUIC crypto, this could happen when the client sends full CHLO and + // 0-RTT request, then receives an inchoate REJ and sends an inchoate + // CHLO. The client then gets the ACK of the inchoate CHLO or the client + // gets the full REJ and needs to verify the proof (before it sends the + // full CHLO), such that there is no outstanding crypto data. + // Retransmission alarm fires in TLP mode which tries to retransmit the + // 0-RTT request (without encryption). + QUIC_DLOG(INFO) << ENDPOINT << "Try to send data of stream " << id + << " before encryption is established."; } return QuicConsumedData(0, false); } SetTransmissionType(type); - const auto current_level = connection()->encryption_level(); - if (!use_encryption_level_context()) { - if (level.has_value()) { - connection()->SetDefaultEncryptionLevel(level.value()); - } - } - QuicConnection::ScopedEncryptionLevelContext context( - use_encryption_level_context() ? connection() : nullptr, - use_encryption_level_context() ? level.value() : NUM_ENCRYPTION_LEVELS); + QuicConnection::ScopedEncryptionLevelContext context(connection(), level); QuicConsumedData data = connection_->SendStreamData(id, write_length, offset, state); @@ -799,19 +835,10 @@ QuicConsumedData QuicSession::WritevData( write_blocked_streams_.UpdateBytesForStream(id, data.bytes_consumed); } - // Restore the encryption level. - if (!use_encryption_level_context()) { - // Restore the encryption level. - if (level.has_value()) { - connection()->SetDefaultEncryptionLevel(current_level); - } - } - return data; } -size_t QuicSession::SendCryptoData(EncryptionLevel level, - size_t write_length, +size_t QuicSession::SendCryptoData(EncryptionLevel level, size_t write_length, QuicStreamOffset offset, TransmissionType type) { QUICHE_DCHECK(QuicVersionUsesCryptoFrames(transport_version())); @@ -826,18 +853,9 @@ size_t QuicSession::SendCryptoData(EncryptionLevel level, return 0; } SetTransmissionType(type); - const auto current_level = connection()->encryption_level(); - if (!use_encryption_level_context()) { - connection_->SetDefaultEncryptionLevel(level); - } - QuicConnection::ScopedEncryptionLevelContext context( - use_encryption_level_context() ? connection() : nullptr, level); + QuicConnection::ScopedEncryptionLevelContext context(connection(), level); const auto bytes_consumed = connection_->SendCryptoData(level, write_length, offset); - if (!use_encryption_level_context()) { - // Restores encryption level. - connection_->SetDefaultEncryptionLevel(current_level); - } return bytes_consumed; } @@ -852,21 +870,13 @@ bool QuicSession::WriteControlFrame(const QuicFrame& frame, TransmissionType type) { QUICHE_DCHECK(connection()->connected()) << ENDPOINT << "Try to write control frames when connection is closed."; - if (connection_->encrypted_control_frames()) { - QUIC_RELOADABLE_FLAG_COUNT(quic_encrypted_control_frames); - if (!IsEncryptionEstablished()) { - QUIC_BUG(quic_bug_10866_4) - << ENDPOINT << "Tried to send control frame " << frame - << " before encryption is established. Last decrypted level: " - << EncryptionLevelToString(connection_->last_decrypted_level()); - return false; - } + if (!IsEncryptionEstablished()) { + // Suppress the write before encryption gets established. + return false; } SetTransmissionType(type); QuicConnection::ScopedEncryptionLevelContext context( - use_encryption_level_context() ? connection() : nullptr, - use_encryption_level_context() ? GetEncryptionLevelToSendApplicationData() - : NUM_ENCRYPTION_LEVELS); + connection(), GetEncryptionLevelToSendApplicationData()); return connection_->SendControlFrame(frame); } @@ -885,12 +895,12 @@ void QuicSession::ResetStream(QuicStreamId id, QuicRstStreamErrorCode error) { } QuicConnection::ScopedPacketFlusher flusher(connection()); - MaybeSendStopSendingFrame(id, error); - MaybeSendRstStreamFrame(id, error, 0); + MaybeSendStopSendingFrame(id, QuicResetStreamError::FromInternal(error)); + MaybeSendRstStreamFrame(id, QuicResetStreamError::FromInternal(error), 0); } void QuicSession::MaybeSendRstStreamFrame(QuicStreamId id, - QuicRstStreamErrorCode error, + QuicResetStreamError error, QuicStreamOffset bytes_written) { if (!connection()->connected()) { return; @@ -901,11 +911,11 @@ void QuicSession::MaybeSendRstStreamFrame(QuicStreamId id, control_frame_manager_.WriteOrBufferRstStream(id, error, bytes_written); } - connection_->OnStreamReset(id, error); + connection_->OnStreamReset(id, error.internal_code()); } void QuicSession::MaybeSendStopSendingFrame(QuicStreamId id, - QuicRstStreamErrorCode error) { + QuicResetStreamError error) { if (!connection()->connected()) { return; } @@ -920,15 +930,12 @@ void QuicSession::SendGoAway(QuicErrorCode error_code, const std::string& reason) { // GOAWAY frame is not supported in IETF QUIC. QUICHE_DCHECK(!VersionHasIetfQuicFrames(transport_version())); - if (GetQuicReloadableFlag(quic_encrypted_goaway)) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_encrypted_goaway, 1, 2); - if (!IsEncryptionEstablished()) { - QUIC_CODE_COUNT(quic_goaway_before_encryption_established); - connection_->CloseConnection( - error_code, reason, - ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); - return; - } + if (!IsEncryptionEstablished()) { + QUIC_CODE_COUNT(quic_goaway_before_encryption_established); + connection_->CloseConnection( + error_code, reason, + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + return; } if (transport_goaway_sent_) { return; @@ -978,8 +985,7 @@ void QuicSession::SendMaxStreams(QuicStreamCount stream_count, } void QuicSession::InsertLocallyClosedStreamsHighestOffset( - const QuicStreamId id, - QuicStreamOffset offset) { + const QuicStreamId id, QuicStreamOffset offset) { locally_closed_streams_highest_offset_[id] = offset; } @@ -1067,9 +1073,14 @@ void QuicSession::ClosePendingStream(QuicStreamId stream_id) { } } +bool QuicSession::ShouldProcessFrameByPendingStream(QuicFrameType type, + QuicStreamId id) const { + return UsesPendingStreamForFrame(type, id) && + stream_map_.find(id) == stream_map_.end(); +} + void QuicSession::OnFinalByteOffsetReceived( - QuicStreamId stream_id, - QuicStreamOffset final_byte_offset) { + QuicStreamId stream_id, QuicStreamOffset final_byte_offset) { auto it = locally_closed_streams_highest_offset_.find(stream_id); if (it == locally_closed_streams_highest_offset_.end()) { return; @@ -1321,8 +1332,7 @@ void QuicSession::OnConfigNegotiated() { // Or if this session is configured on TLS enabled QUIC versions, // attempt to retransmit 0-RTT data if there's any. // TODO(fayang): consider removing this OnCanWrite call. - if ((!connection_->donot_write_mid_packet_processing() || - !connection_->framer().is_processing_packet()) && + if (!connection_->framer().is_processing_packet() && (connection_->version().AllowsLowFlowControlLimits() || version().UsesTls())) { QUIC_CODE_COUNT(quic_session_on_can_write_on_config_negotiated); @@ -1331,8 +1341,7 @@ void QuicSession::OnConfigNegotiated() { } absl::optional QuicSession::OnAlpsData( - const uint8_t* /*alps_data*/, - size_t /*alps_length*/) { + const uint8_t* /*alps_data*/, size_t /*alps_length*/) { return absl::nullopt; } @@ -1560,10 +1569,8 @@ void QuicSession::OnNewSessionFlowControlWindow(QuicStreamOffset new_window) { } bool QuicSession::OnNewDecryptionKeyAvailable( - EncryptionLevel level, - std::unique_ptr decrypter, - bool set_alternative_decrypter, - bool latch_once_used) { + EncryptionLevel level, std::unique_ptr decrypter, + bool set_alternative_decrypter, bool latch_once_used) { if (connection_->version().handshake_protocol == PROTOCOL_TLS1_3 && !connection()->framer().HasEncrypterOfEncryptionLevel( QuicUtils::GetEncryptionLevel( @@ -1586,8 +1593,7 @@ bool QuicSession::OnNewDecryptionKeyAvailable( } void QuicSession::OnNewEncryptionKeyAvailable( - EncryptionLevel level, - std::unique_ptr encrypter) { + EncryptionLevel level, std::unique_ptr encrypter) { connection()->SetEncrypter(level, std::move(encrypter)); if (connection_->version().handshake_protocol != PROTOCOL_TLS1_3) { return; @@ -1629,8 +1635,7 @@ void QuicSession::SetDefaultEncryptionLevel(EncryptionLevel level) { // Retransmit old 0-RTT data (if any) with the new 0-RTT keys, since // they can't be decrypted by the server. connection_->MarkZeroRttPacketsForRetransmission(0); - if (!connection_->donot_write_mid_packet_processing() || - !connection_->framer().is_processing_packet()) { + if (!connection_->framer().is_processing_packet()) { // TODO(fayang): consider removing this OnCanWrite call. // Given any streams blocked by encryption a chance to write. QUIC_CODE_COUNT( @@ -1666,21 +1671,24 @@ void QuicSession::OnTlsHandshakeComplete() { // Server sends HANDSHAKE_DONE to signal confirmation of the handshake // to the client. control_frame_manager_.WriteOrBufferHandshakeDone(); - if (GetQuicReloadableFlag(quic_enable_token_based_address_validation) && - connection()->version().HasIetfQuicFrames()) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_enable_token_based_address_validation, - 1, 2); + if (connection()->version().HasIetfQuicFrames()) { MaybeSendAddressToken(); } } } -void QuicSession::MaybeSendAddressToken() { +bool QuicSession::MaybeSendAddressToken() { QUICHE_DCHECK(perspective_ == Perspective::IS_SERVER && connection()->version().HasIetfQuicFrames()); - std::string address_token = GetCryptoStream()->GetAddressToken(); + absl::optional cached_network_params; + if (add_cached_network_parameters_to_address_token()) { + cached_network_params = GenerateCachedNetworkParameters(); + } + std::string address_token = GetCryptoStream()->GetAddressToken( + cached_network_params.has_value() ? &cached_network_params.value() + : nullptr); if (address_token.empty()) { - return; + return false; } const size_t buf_len = address_token.length() + 1; auto buffer = std::make_unique(buf_len); @@ -1690,6 +1698,13 @@ void QuicSession::MaybeSendAddressToken() { writer.WriteBytes(address_token.data(), address_token.length()); control_frame_manager_.WriteOrBufferNewToken( absl::string_view(buffer.get(), buf_len)); + if (add_cached_network_parameters_to_address_token() && + cached_network_params.has_value()) { + connection()->OnSendConnectionState(*cached_network_params); + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_add_cached_network_parameters_to_address_token2, 1, 2); + } + return true; } void QuicSession::DiscardOldDecryptionKey(EncryptionLevel level) { @@ -1758,8 +1773,7 @@ bool QuicSession::FillTransportParameters(TransportParameters* params) { } QuicErrorCode QuicSession::ProcessTransportParameters( - const TransportParameters& params, - bool is_resumption, + const TransportParameters& params, bool is_resumption, std::string* error_details) { return config_.ProcessTransportParameters(params, is_resumption, error_details); @@ -1787,8 +1801,7 @@ void QuicSession::OnCryptoHandshakeMessageReceived( const CryptoHandshakeMessage& /*message*/) {} void QuicSession::RegisterStreamPriority( - QuicStreamId id, - bool is_static, + QuicStreamId id, bool is_static, const spdy::SpdyStreamPrecedence& precedence) { write_blocked_streams()->RegisterStream(id, is_static, precedence); } @@ -1798,21 +1811,18 @@ void QuicSession::UnregisterStreamPriority(QuicStreamId id, bool is_static) { } void QuicSession::UpdateStreamPriority( - QuicStreamId id, - const spdy::SpdyStreamPrecedence& new_precedence) { + QuicStreamId id, const spdy::SpdyStreamPrecedence& new_precedence) { write_blocked_streams()->UpdateStreamPriority(id, new_precedence); } -QuicConfig* QuicSession::config() { - return &config_; -} +QuicConfig* QuicSession::config() { return &config_; } void QuicSession::ActivateStream(std::unique_ptr stream) { QuicStreamId stream_id = stream->id(); bool is_static = stream->is_static(); QUIC_DVLOG(1) << ENDPOINT << "num_streams: " << stream_map_.size() << ". activating stream " << stream_id; - QUICHE_DCHECK(!QuicContainsKey(stream_map_, stream_id)); + QUICHE_DCHECK(!stream_map_.contains(stream_id)); stream_map_[stream_id] = std::move(stream); if (is_static) { ++num_static_streams_; @@ -1892,7 +1902,7 @@ QuicStreamCount QuicSession::GetAdvertisedMaxIncomingBidirectionalStreams() } QuicStream* QuicSession::GetOrCreateStream(const QuicStreamId stream_id) { - QUICHE_DCHECK(!QuicContainsKey(pending_stream_map_, stream_id)); + QUICHE_DCHECK(!pending_stream_map_.contains(stream_id)); if (QuicUtils::IsCryptoStreamId(transport_version(), stream_id)) { return GetMutableCryptoStream(); } @@ -1931,7 +1941,7 @@ QuicStream* QuicSession::GetOrCreateStream(const QuicStreamId stream_id) { } void QuicSession::StreamDraining(QuicStreamId stream_id, bool unidirectional) { - QUICHE_DCHECK(QuicContainsKey(stream_map_, stream_id)); + QUICHE_DCHECK(stream_map_.contains(stream_id)); QUIC_DVLOG(1) << ENDPOINT << "Stream " << stream_id << " is draining"; if (VersionHasIetfQuicFrames(transport_version())) { ietf_streamid_manager_.OnStreamClosed(stream_id); @@ -1942,7 +1952,9 @@ void QuicSession::StreamDraining(QuicStreamId stream_id, bool unidirectional) { ++num_draining_streams_; if (!IsIncomingStream(stream_id)) { ++num_outgoing_draining_streams_; - OnCanCreateNewOutgoingStream(unidirectional); + if (!VersionHasIetfQuicFrames(transport_version())) { + OnCanCreateNewOutgoingStream(unidirectional); + } } } @@ -2016,8 +2028,7 @@ void QuicSession::DeleteConnection() { } bool QuicSession::MaybeSetStreamPriority( - QuicStreamId stream_id, - const spdy::SpdyStreamPrecedence& precedence) { + QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence) { auto active_stream = stream_map_.find(stream_id); if (active_stream != stream_map_.end()) { active_stream->second->SetPriority(precedence); @@ -2047,7 +2058,7 @@ bool QuicSession::IsOpenStream(QuicStreamId id) { if (it != stream_map_.end()) { return !it->second->IsZombie(); } - if (QuicContainsKey(pending_stream_map_, id) || + if (pending_stream_map_.contains(id) || QuicUtils::IsCryptoStreamId(transport_version(), id)) { // Stream is active return true; @@ -2101,24 +2112,32 @@ void QuicSession::SendAckFrequency(const QuicAckFrequencyFrame& frame) { } void QuicSession::SendNewConnectionId(const QuicNewConnectionIdFrame& frame) { + // Count NEW_CONNECTION_ID frames sent to client. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 1, 6); control_frame_manager_.WriteOrBufferNewConnectionId( frame.connection_id, frame.sequence_number, frame.retire_prior_to, frame.stateless_reset_token); } void QuicSession::SendRetireConnectionId(uint64_t sequence_number) { + // Count RETIRE_CONNECTION_ID frames sent to client. + QUIC_RELOADABLE_FLAG_COUNT_N(quic_connection_migration_use_new_cid_v2, 2, 6); control_frame_manager_.WriteOrBufferRetireConnectionId(sequence_number); } void QuicSession::OnServerConnectionIdIssued( const QuicConnectionId& server_connection_id) { - visitor_->OnNewConnectionIdSent(connection_->connection_id(), - server_connection_id); + if (visitor_) { + visitor_->OnNewConnectionIdSent( + connection_->GetOneActiveServerConnectionId(), server_connection_id); + } } void QuicSession::OnServerConnectionIdRetired( const QuicConnectionId& server_connection_id) { - visitor_->OnConnectionIdRetired(server_connection_id); + if (visitor_) { + visitor_->OnConnectionIdRetired(server_connection_id); + } } bool QuicSession::IsConnectionFlowControlBlocked() const { @@ -2265,8 +2284,8 @@ void QuicSession::OnFrameLost(const QuicFrame& frame) { frame.stream_frame.data_length, frame.stream_frame.fin); if (stream->HasPendingRetransmission() && - !QuicContainsKey(streams_with_pending_retransmission_, - frame.stream_frame.stream_id)) { + !streams_with_pending_retransmission_.contains( + frame.stream_frame.stream_id)) { streams_with_pending_retransmission_.insert( std::make_pair(frame.stream_frame.stream_id, true)); } @@ -2390,8 +2409,8 @@ bool QuicSession::RetransmitLostData() { } // Retransmit crypto data in stream 1 frames (version < 47). if (!uses_crypto_frames && - QuicContainsKey(streams_with_pending_retransmission_, - QuicUtils::GetCryptoStreamId(transport_version()))) { + streams_with_pending_retransmission_.contains( + QuicUtils::GetCryptoStreamId(transport_version()))) { // Retransmit crypto data first. QuicStream* crypto_stream = GetStream(QuicUtils::GetCryptoStreamId(transport_version())); @@ -2456,20 +2475,23 @@ void QuicSession::SetTransmissionType(TransmissionType type) { connection_->SetTransmissionType(type); } -MessageResult QuicSession::SendMessage(QuicMemSliceSpan message) { +MessageResult QuicSession::SendMessage(absl::Span message) { return SendMessage(message, /*flush=*/false); } -MessageResult QuicSession::SendMessage(QuicMemSliceSpan message, bool flush) { +MessageResult QuicSession::SendMessage(QuicMemSlice message) { + return SendMessage(absl::MakeSpan(&message, 1), /*flush=*/false); +} + +MessageResult QuicSession::SendMessage(absl::Span message, + bool flush) { QUICHE_DCHECK(connection_->connected()) << ENDPOINT << "Try to write messages when connection is closed."; if (!IsEncryptionEstablished()) { return {MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED, 0}; } QuicConnection::ScopedEncryptionLevelContext context( - use_encryption_level_context() ? connection() : nullptr, - use_encryption_level_context() ? GetEncryptionLevelToSendApplicationData() - : NUM_ENCRYPTION_LEVELS); + connection(), GetEncryptionLevelToSendApplicationData()); MessageStatus result = connection_->SendMessage(last_message_id_ + 1, message, flush); if (result == MESSAGE_STATUS_SUCCESS) { @@ -2488,9 +2510,7 @@ void QuicSession::OnMessageLost(QuicMessageId message_id) { << " is considered lost"; } -void QuicSession::CleanUpClosedStreams() { - closed_streams_.clear(); -} +void QuicSession::CleanUpClosedStreams() { closed_streams_.clear(); } QuicPacketLength QuicSession::GetCurrentLargestMessagePayload() const { return connection_->GetCurrentLargestMessagePayload(); @@ -2599,6 +2619,21 @@ EncryptionLevel QuicSession::GetEncryptionLevelToSendApplicationData() const { return connection_->framer().GetEncryptionLevelToSendApplicationData(); } +void QuicSession::ProcessAllPendingStreams() { + std::vector pending_streams; + pending_streams.reserve(pending_stream_map_.size()); + for (auto it = pending_stream_map_.cbegin(); it != pending_stream_map_.cend(); + ++it) { + pending_streams.push_back(it->second.get()); + } + for (auto* pending_stream : pending_streams) { + MaybeProcessPendingStream(pending_stream); + if (!connection()->connected()) { + return; + } + } +} + void QuicSession::ValidatePath( std::unique_ptr context, std::unique_ptr result_delegate) { @@ -2611,20 +2646,33 @@ bool QuicSession::HasPendingPathValidation() const { bool QuicSession::MigratePath(const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, - QuicPacketWriter* writer, - bool owns_writer) { + QuicPacketWriter* writer, bool owns_writer) { return connection_->MigratePath(self_address, peer_address, writer, owns_writer); } -bool QuicSession::ValidateToken(absl::string_view token) const { +bool QuicSession::ValidateToken(absl::string_view token) { QUICHE_DCHECK_EQ(perspective_, Perspective::IS_SERVER); + if (GetQuicFlag(FLAGS_quic_reject_retry_token_in_initial_packet)) { + return false; + } if (token.empty() || token[0] != 0) { // Validate the prefix for token received in NEW_TOKEN frame. return false; } - return GetCryptoStream()->ValidateAddressToken( + const bool valid = GetCryptoStream()->ValidateAddressToken( absl::string_view(token.data() + 1, token.length() - 1)); + if (add_cached_network_parameters_to_address_token() && valid) { + const CachedNetworkParameters* cached_network_params = + GetCryptoStream()->PreviousCachedNetworkParams(); + if (cached_network_params != nullptr && + cached_network_params->timestamp() > 0) { + connection()->OnReceiveConnectionState(*cached_network_params); + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_add_cached_network_parameters_to_address_token2, 2, 2); + } + } + return valid; } #undef ENDPOINT // undef for jumbo builds diff --git a/gquiche/quic/core/quic_session.h b/gquiche/quic/core/quic_session.h index 82debb97..67e6c804 100644 --- a/gquiche/quic/core/quic_session.h +++ b/gquiche/quic/core/quic_session.h @@ -17,9 +17,14 @@ #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "absl/types/span.h" +#include "gquiche/quic/core/crypto/tls_connection.h" #include "gquiche/quic/core/frames/quic_ack_frequency_frame.h" +#include "gquiche/quic/core/frames/quic_stop_sending_frame.h" +#include "gquiche/quic/core/frames/quic_window_update_frame.h" #include "gquiche/quic/core/handshaker_delegate_interface.h" #include "gquiche/quic/core/legacy_quic_stream_id_manager.h" +#include "gquiche/quic/core/proto/cached_network_parameters_proto.h" #include "gquiche/quic/core/quic_connection.h" #include "gquiche/quic/core/quic_control_frame_manager.h" #include "gquiche/quic/core/quic_crypto_stream.h" @@ -35,10 +40,11 @@ #include "gquiche/quic/core/session_notifier_interface.h" #include "gquiche/quic/core/stream_delegate_interface.h" #include "gquiche/quic/core/uber_quic_stream_id_manager.h" -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flags.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" #include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -94,13 +100,11 @@ class QUIC_EXPORT_PRIVATE QuicSession }; // Does not take ownership of |connection| or |visitor|. - QuicSession(QuicConnection* connection, - Visitor* owner, + QuicSession(QuicConnection* connection, Visitor* owner, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, QuicStreamCount num_expected_unidirectional_static_streams); - QuicSession(QuicConnection* connection, - Visitor* owner, + QuicSession(QuicConnection* connection, Visitor* owner, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, QuicStreamCount num_expected_unidirectional_static_streams, @@ -168,8 +172,8 @@ class QUIC_EXPORT_PRIVATE QuicSession override; std::unique_ptr CreateCurrentOneRttEncrypter() override; void BeforeConnectionCloseSent() override {} - bool ValidateToken(absl::string_view token) const override; - void MaybeSendAddressToken() override; + bool ValidateToken(absl::string_view token) override; + bool MaybeSendAddressToken() override; bool IsKnownServerAddress( const QuicSocketAddress& /*address*/) const override { return false; @@ -180,14 +184,12 @@ class QUIC_EXPORT_PRIVATE QuicSession QuicStreamOffset offset, QuicByteCount data_length, QuicDataWriter* writer) override; - bool WriteCryptoData(EncryptionLevel level, - QuicStreamOffset offset, + bool WriteCryptoData(EncryptionLevel level, QuicStreamOffset offset, QuicByteCount data_length, QuicDataWriter* writer) override; // SessionNotifierInterface methods: - bool OnFrameAcked(const QuicFrame& frame, - QuicTime::Delta ack_delay_time, + bool OnFrameAcked(const QuicFrame& frame, QuicTime::Delta ack_delay_time, QuicTime receive_timestamp) override; void OnStreamFrameRetransmitted(const QuicStreamFrame& frame) override; void OnFrameLost(const QuicFrame& frame) override; @@ -208,31 +210,39 @@ class QUIC_EXPORT_PRIVATE QuicSession const QuicSocketAddress& peer_address, const QuicReceivedPacket& packet); - // Called by application to send |message|. Data copy can be avoided if - // |message| is provided in reference counted memory. - // Please note, |message| provided in reference counted memory would be moved - // internally when message is successfully sent. Thereafter, it would be - // undefined behavior if callers try to access the slices through their own - // copy of the span object. - // Returns the message result which includes the message status and message ID - // (valid if the write succeeds). SendMessage flushes a message packet even it - // is not full. If the application wants to bundle other data in the same - // packet, please consider adding a packet flusher around the SendMessage - // and/or WritevData calls. + // Sends |message| as a QUIC DATAGRAM frame (QUIC MESSAGE frame in gQUIC). + // See for + // more details. // - // OnMessageAcked and OnMessageLost are called when a particular message gets - // acked or lost. + // Returns a MessageResult struct which includes the status of the write + // operation and a message ID. The message ID (not sent on the wire) can be + // used to track the message; OnMessageAcked and OnMessageLost are called when + // a specific message gets acked or lost. + // + // If the write operation is successful, all of the slices in |message| are + // consumed, leaving them empty. If MESSAGE_STATUS_INTERNAL_ERROR is + // returned, the slices in question may or may not be consumed; it is no + // longer safe to access those. For all other status codes, |message| is kept + // intact. // // Note that SendMessage will fail with status = MESSAGE_STATUS_BLOCKED - // if connection is congestion control blocked or underlying socket is write - // blocked. In this case the caller can retry sending message again when + // if the connection is congestion control blocked or the underlying socket is + // write blocked. In this case the caller can retry sending message again when // connection becomes available, for example after getting OnCanWrite() // callback. - MessageResult SendMessage(QuicMemSliceSpan message); + // + // SendMessage flushes the current packet even it is not full; if the + // application needs to bundle other data in the same packet, consider using + // QuicConnection::ScopedPacketFlusher around the relevant write operations. + MessageResult SendMessage(absl::Span message); // Same as above SendMessage, except caller can specify if the given |message| // should be flushed even if the underlying connection is deemed unwritable. - MessageResult SendMessage(QuicMemSliceSpan message, bool flush); + MessageResult SendMessage(absl::Span message, bool flush); + + // Single-slice version of SendMessage(). Unlike the version above, this + // version always takes ownership of the slice. + MessageResult SendMessage(QuicMemSlice message); // Called when message with |message_id| gets acked. virtual void OnMessageAcked(QuicMessageId message_id, @@ -288,8 +298,7 @@ class QUIC_EXPORT_PRIVATE QuicSession bool set_alternative_decrypter, bool latch_once_used) override; void OnNewEncryptionKeyAvailable( - EncryptionLevel level, - std::unique_ptr encrypter) override; + EncryptionLevel level, std::unique_ptr encrypter) override; void SetDefaultEncryptionLevel(EncryptionLevel level) override; void OnTlsHandshakeComplete() override; void DiscardOldDecryptionKey(EncryptionLevel level) override; @@ -312,8 +321,7 @@ class QUIC_EXPORT_PRIVATE QuicSession std::string error_details) override; // Sets priority in the write blocked list. void RegisterStreamPriority( - QuicStreamId id, - bool is_static, + QuicStreamId id, bool is_static, const spdy::SpdyStreamPrecedence& precedence) override; // Clears priority from the write blocked list. void UnregisterStreamPriority(QuicStreamId id, bool is_static) override; @@ -327,15 +335,12 @@ class QUIC_EXPORT_PRIVATE QuicSession // indicating if the fin bit was consumed. This does not indicate the data // has been sent on the wire: it may have been turned into a packet and queued // if the socket was unexpectedly blocked. - QuicConsumedData WritevData(QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, TransmissionType type, - absl::optional level) override; + EncryptionLevel level) override; - size_t SendCryptoData(EncryptionLevel level, - size_t write_length, + size_t SendCryptoData(EncryptionLevel level, size_t write_length, QuicStreamOffset offset, TransmissionType type) override; @@ -457,8 +462,7 @@ class QUIC_EXPORT_PRIVATE QuicSession // Switch to the path described in |context| without validating the path. bool MigratePath(const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, - QuicPacketWriter* writer, - bool owns_writer); + QuicPacketWriter* writer, bool owns_writer); // Returns the largest payload that will fit into a single MESSAGE frame. // Because overhead can vary during a connection, this method should be @@ -579,12 +583,12 @@ class QUIC_EXPORT_PRIVATE QuicSession // Does actual work of sending RESET_STREAM, if the stream type allows. // Also informs the connection so that pending stream frames can be flushed. virtual void MaybeSendRstStreamFrame(QuicStreamId id, - QuicRstStreamErrorCode error, + QuicResetStreamError error, QuicStreamOffset bytes_written); // Sends a STOP_SENDING frame if the stream type allows. virtual void MaybeSendStopSendingFrame(QuicStreamId id, - QuicRstStreamErrorCode error); + QuicResetStreamError error); // Returns the encryption level to send application data. EncryptionLevel GetEncryptionLevelToSendApplicationData() const; @@ -593,9 +597,10 @@ class QUIC_EXPORT_PRIVATE QuicSession return user_agent_id_; } + // TODO(wub): remove saving user-agent to QuicSession. void SetUserAgentId(std::string user_agent_id) { user_agent_id_ = std::move(user_agent_id); - connection()->OnUserAgentIdKnown(); + connection()->OnUserAgentIdKnown(user_agent_id_.value()); } void SetSourceAddressTokenToSend(absl::string_view token) { @@ -610,13 +615,29 @@ class QUIC_EXPORT_PRIVATE QuicSession return liveness_testing_in_progress_; } - bool use_write_or_buffer_data_at_level() const { - return use_write_or_buffer_data_at_level_; + bool permutes_tls_extensions() const { return permutes_tls_extensions_; } + + virtual QuicSSLConfig GetSSLConfig() const { return QuicSSLConfig(); } + + // Latched value of flag --quic_tls_server_support_client_cert. + bool support_client_cert() const { return support_client_cert_; } + + // Get latched flag value. + bool add_cached_network_parameters_to_address_token() const { + return add_cached_network_parameters_to_address_token_; } - bool use_encryption_level_context() const { - return connection_->use_encryption_level_context() && - use_write_or_buffer_data_at_level_; + // Try converting all pending streams to normal streams. + void ProcessAllPendingStreams(); + + const ParsedQuicVersionVector& client_original_supported_versions() const { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + return client_original_supported_versions_; + } + void set_client_original_supported_versions( + const ParsedQuicVersionVector& client_original_supported_versions) { + QUICHE_DCHECK_EQ(perspective_, Perspective::IS_CLIENT); + client_original_supported_versions_ = client_original_supported_versions; } protected: @@ -671,11 +692,17 @@ class QUIC_EXPORT_PRIVATE QuicSession virtual void OnFinalByteOffsetReceived(QuicStreamId id, QuicStreamOffset final_byte_offset); - // Returns true if incoming unidirectional streams should be buffered until - // the first byte of the stream arrives. - // If a subclass returns true here, it should make sure to implement - // ProcessPendingStream(). - virtual bool UsesPendingStreams() const { return false; } + // Returns true if a frame with the given type and id can be prcoessed by a + // PendingStream. However, the frame will always be processed by a QuicStream + // if one exists with the given stream_id. + virtual bool UsesPendingStreamForFrame(QuicFrameType /*type*/, + QuicStreamId /*stream_id*/) const { + return false; + } + + // Returns true if a pending stream should be converted to a real stream after + // a corresponding STREAM_FRAME is received. + virtual bool ShouldProcessPendingStreamImmediately() const { return true; } spdy::SpdyPriority GetSpdyPriorityofStream(QuicStreamId stream_id) const { return write_blocked_streams_.GetSpdyPriorityofStream(stream_id); @@ -777,6 +804,20 @@ class QUIC_EXPORT_PRIVATE QuicSession // streams. QuicStream* GetActiveStream(QuicStreamId id) const; + const UberQuicStreamIdManager& ietf_streamid_manager() const { + QUICHE_DCHECK(VersionHasIetfQuicFrames(transport_version())); + return ietf_streamid_manager_; + } + + // Only called at a server session. Generate a CachedNetworkParameters that + // can be sent to the client as part of the address token, based on the latest + // bandwidth/rtt information. If return absl::nullopt, address token will not + // contain the CachedNetworkParameters. + virtual absl::optional + GenerateCachedNetworkParameters() const { + return absl::nullopt; + } + private: friend class test::QuicSessionPeer; @@ -821,6 +862,7 @@ class QUIC_EXPORT_PRIVATE QuicSession // closed. QuicStream* GetStream(QuicStreamId id) const; + // Can return NULL, e.g., if the stream has been closed before. PendingStream* GetOrCreatePendingStream(QuicStreamId stream_id); // Let streams and control frame managers retransmit lost data, returns true @@ -833,14 +875,30 @@ class QUIC_EXPORT_PRIVATE QuicSession // Closes the pending stream |stream_id| before it has been created. void ClosePendingStream(QuicStreamId stream_id); - // Creates or gets pending stream, feeds it with |frame|, and processes the - // pending stream. - void PendingStreamOnStreamFrame(const QuicStreamFrame& frame); + // Whether the frame with given type and id should be feed to a pending + // stream. + bool ShouldProcessFrameByPendingStream(QuicFrameType type, + QuicStreamId id) const; + + // Process the pending stream if possible. + void MaybeProcessPendingStream(PendingStream* pending); + + // Creates or gets pending stream, feeds it with |frame|, and returns the + // pending stream. Can return NULL, e.g., if the stream ID is invalid. + PendingStream* PendingStreamOnStreamFrame(const QuicStreamFrame& frame); // Creates or gets pending strea, feed it with |frame|, and closes the pending // stream. void PendingStreamOnRstStream(const QuicRstStreamFrame& frame); + // Creates or gets pending stream, feeds it with |frame|, and records the + // max_data in the pending stream. + void PendingStreamOnWindowUpdateFrame(const QuicWindowUpdateFrame& frame); + + // Creates or gets pending stream, feeds it with |frame|, and records the + // ietf_error_code in the pending stream. + void PendingStreamOnStopSendingFrame(const QuicStopSendingFrame& frame); + // Keep track of highest received byte offset of locally closed streams, while // waiting for a definitive final highest offset from the peer. absl::flat_hash_map @@ -925,7 +983,8 @@ class QUIC_EXPORT_PRIVATE QuicSession // TODO(fayang): switch to linked_hash_set when chromium supports it. The bool // is not used here. // List of streams with pending retransmissions. - QuicLinkedHashMap streams_with_pending_retransmission_; + quiche::QuicheLinkedHashMap + streams_with_pending_retransmission_; // Clean up closed_streams_ when this alarm fires. std::unique_ptr closed_streams_clean_up_alarm_; @@ -934,6 +993,11 @@ class QUIC_EXPORT_PRIVATE QuicSession // list may be a superset of the connection framer's supported versions. ParsedQuicVersionVector supported_versions_; + // Only non-empty on the client after receiving a version negotiation packet, + // contains the configured versions from the original session before version + // negotiation was received. + ParsedQuicVersionVector client_original_supported_versions_; + absl::optional user_agent_id_; // Initialized to false. Set to true when the session has been properly @@ -947,8 +1011,15 @@ class QUIC_EXPORT_PRIVATE QuicSession // creation of new outgoing bidirectional streams. bool liveness_testing_in_progress_; - const bool use_write_or_buffer_data_at_level_ = - GetQuicReloadableFlag(quic_use_write_or_buffer_data_at_level); + const bool add_cached_network_parameters_to_address_token_ = + GetQuicReloadableFlag( + quic_add_cached_network_parameters_to_address_token2); + + // Whether BoringSSL randomizes the order of TLS extensions. + bool permutes_tls_extensions_ = true; + + const bool support_client_cert_ = + GetQuicRestartFlag(quic_tls_server_support_client_cert); }; } // namespace quic diff --git a/gquiche/quic/core/quic_session_test.cc b/gquiche/quic/core/quic_session_test.cc index e94a6e57..47d6d7f9 100644 --- a/gquiche/quic/core/quic_session_test.cc +++ b/gquiche/quic/core/quic_session_test.cc @@ -23,14 +23,13 @@ #include "gquiche/quic/core/quic_data_writer.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_stream.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_mem_slice_storage.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/quic/platform/api/quic_test_mem_slice_vector.h" #include "gquiche/quic/test_tools/mock_quic_session_visitor.h" #include "gquiche/quic/test_tools/quic_config_peer.h" #include "gquiche/quic/test_tools/quic_connection_peer.h" @@ -142,10 +141,19 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { void OnHandshakePacketSent() override {} void OnHandshakeDoneReceived() override {} void OnNewTokenReceived(absl::string_view /*token*/) override {} - std::string GetAddressToken() const override { return ""; } + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } bool ValidateAddressToken(absl::string_view /*token*/) const override { return true; } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} HandshakeState GetHandshakeState() const override { return one_rtt_keys_available() ? HANDSHAKE_COMPLETE : HANDSHAKE_START; } @@ -169,6 +177,15 @@ class TestCryptoStream : public QuicCryptoStream, public QuicCryptoHandshaker { void OnConnectionClosed(QuicErrorCode /*error*/, ConnectionCloseSource /*source*/) override {} + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + + SSL* GetSsl() const override { return nullptr; } + private: using QuicCryptoStream::session; @@ -188,8 +205,8 @@ class TestStream : public QuicStream { StreamType type) : QuicStream(id, session, is_static, type) {} - TestStream(PendingStream* pending, QuicSession* session, StreamType type) - : QuicStream(pending, session, type, /*is_static=*/false) {} + TestStream(PendingStream* pending, QuicSession* session) + : QuicStream(pending, session, /*is_static=*/false) {} using QuicStream::CloseWriteSide; using QuicStream::WriteMemSlices; @@ -280,11 +297,7 @@ class TestSession : public QuicSession { } TestStream* CreateIncomingStream(PendingStream* pending) override { - QuicStreamId id = pending->id(); - TestStream* stream = new TestStream( - pending, this, - DetermineStreamType(id, connection()->version(), perspective(), - /*is_incoming=*/true, BIDIRECTIONAL)); + TestStream* stream = new TestStream(pending, this); ActivateStream(absl::WrapUnique(stream)); ++num_incoming_streams_created_; return stream; @@ -294,6 +307,9 @@ class TestSession : public QuicSession { // test that the session handles pending streams correctly in terms of // receiving stream frames. QuicStream* ProcessPendingStream(PendingStream* pending) override { + if (pending->is_bidirectional()) { + return CreateIncomingStream(pending); + } struct iovec iov; if (pending->sequencer()->GetReadableRegion(&iov)) { // Create TestStream once the first byte is received. @@ -314,12 +330,10 @@ class TestSession : public QuicSession { return GetNumActiveStreams() > 0; } - QuicConsumedData WritevData(QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, TransmissionType type, - absl::optional level) override { + EncryptionLevel level) override { bool fin = state != NO_FIN; QuicConsumedData consumed(write_length, fin); if (!writev_consumes_all_data_) { @@ -371,12 +385,45 @@ class TestSession : public QuicSession { GetEncryptionLevelToSendApplicationData()); } - bool UsesPendingStreams() const override { return uses_pending_streams_; } + bool UsesPendingStreamForFrame(QuicFrameType type, + QuicStreamId stream_id) const override { + if (!uses_pending_streams_) { + return false; + } + // Uses pending stream for STREAM/RST_STREAM frames with unidirectional read + // stream and uses pending stream for + // STREAM/RST_STREAM/STOP_SENDING/WINDOW_UPDATE frames with bidirectional + // stream. + bool is_incoming_stream = IsIncomingStream(stream_id); + StreamType stream_type = QuicUtils::GetStreamType( + stream_id, perspective(), is_incoming_stream, version()); + switch (type) { + case STREAM_FRAME: + ABSL_FALLTHROUGH_INTENDED; + case RST_STREAM_FRAME: + return is_incoming_stream; + case STOP_SENDING_FRAME: + ABSL_FALLTHROUGH_INTENDED; + case WINDOW_UPDATE_FRAME: + return stream_type == BIDIRECTIONAL; + default: + return false; + } + } + + bool ShouldProcessPendingStreamImmediately() const override { + return process_pending_stream_immediately_; + } void set_uses_pending_streams(bool uses_pending_streams) { uses_pending_streams_ = uses_pending_streams; } + void set_process_pending_stream_immediately( + bool process_pending_stream_immediately) { + process_pending_stream_immediately_ = process_pending_stream_immediately; + } + int num_incoming_streams_created() const { return num_incoming_streams_created_; } @@ -393,6 +440,7 @@ class TestSession : public QuicSession { bool writev_consumes_all_data_; bool uses_pending_streams_; + bool process_pending_stream_immediately_ = true; QuicFrame save_frame_; int num_incoming_streams_created_; }; @@ -452,7 +500,7 @@ class QuicSessionTestBase : public QuicTestWithParam { QuicUtils::GetCryptoStreamId(connection_->transport_version()); } for (QuicStreamId i = first_stream_id; i < 100; i++) { - if (!QuicContainsKey(closed_streams_, i)) { + if (closed_streams_.find(i) == closed_streams_.end()) { EXPECT_FALSE(session_.IsClosedStream(i)) << " stream id: " << i; } else { EXPECT_TRUE(session_.IsClosedStream(i)) << " stream id: " << i; @@ -1529,56 +1577,6 @@ TEST_P(QuicSessionTestServer, HandshakeUnblocksFlowControlBlockedStream) { EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); } -TEST_P(QuicSessionTestServer, HandshakeUnblocksFlowControlBlockedCryptoStream) { - if (QuicVersionUsesCryptoFrames(GetParam().transport_version) || - connection_->encrypted_control_frames()) { - // QUIC version 47 onwards uses CRYPTO frames for the handshake, so this - // test doesn't make sense for those versions since CRYPTO frames aren't - // flow controlled. - return; - } - // Test that if the crypto stream is flow control blocked, then if the SHLO - // contains a larger send window offset, the stream becomes unblocked. - session_.set_writev_consumes_all_data(true); - TestCryptoStream* crypto_stream = session_.GetMutableCryptoStream(); - EXPECT_FALSE(crypto_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); - EXPECT_CALL(*connection_, SendControlFrame(_)) - .WillOnce(Invoke(&ClearControlFrame)); - for (QuicStreamId i = 0; !crypto_stream->IsFlowControlBlocked() && i < 1000u; - i++) { - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); - QuicStreamOffset offset = crypto_stream->stream_bytes_written(); - QuicConfig config; - CryptoHandshakeMessage crypto_message; - config.ToHandshakeMessage(&crypto_message, transport_version()); - crypto_stream->SendHandshakeMessage(crypto_message, ENCRYPTION_INITIAL); - char buf[1000]; - QuicDataWriter writer(1000, buf, quiche::NETWORK_BYTE_ORDER); - crypto_stream->WriteStreamData(offset, crypto_message.size(), &writer); - } - EXPECT_TRUE(crypto_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_TRUE(session_.IsStreamFlowControlBlocked()); - EXPECT_FALSE(session_.HasDataToWrite()); - EXPECT_TRUE(crypto_stream->HasBufferedData()); - - // Now complete the crypto handshake, resulting in an increased flow control - // send window. - CompleteHandshake(); - EXPECT_TRUE(QuicSessionPeer::IsStreamWriteBlocked( - &session_, - QuicUtils::GetCryptoStreamId(connection_->transport_version()))); - // Stream is now unblocked and will no longer have buffered data. - EXPECT_FALSE(crypto_stream->IsFlowControlBlocked()); - EXPECT_FALSE(session_.IsConnectionFlowControlBlocked()); - EXPECT_FALSE(session_.IsStreamFlowControlBlocked()); -} - TEST_P(QuicSessionTestServer, ConnectionFlowControlAccountingRstOutOfOrder) { CompleteHandshake(); // Test that when we receive an out of order stream RST we correctly adjust @@ -1814,7 +1812,9 @@ TEST_P(QuicSessionTestServer, DrainingStreamsDoNotCountAsOpenedOutgoing) { QuicStreamId stream_id = stream->id(); QuicStreamFrame data1(stream_id, true, 0, absl::string_view("HT")); session_.OnStreamFrame(data1); - EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(1); + if (!VersionHasIetfQuicFrames(transport_version())) { + EXPECT_CALL(session_, OnCanCreateNewOutgoingStream(false)).Times(1); + } session_.StreamDraining(stream_id, /*unidirectional=*/false); } @@ -1837,6 +1837,7 @@ TEST_P(QuicSessionTestServer, PendingStreams) { return; } session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(true); QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( transport_version(), Perspective::IS_CLIENT); @@ -1851,11 +1852,50 @@ TEST_P(QuicSessionTestServer, PendingStreams) { EXPECT_EQ(1, session_.num_incoming_streams_created()); } +TEST_P(QuicSessionTestServer, BufferAllIncomingStreams) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + // Read unidirectional stream is still buffered when the first byte arrives. + QuicStreamFrame data2(stream_id, false, 0, absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + // Bidirectional stream is buffered. + QuicStreamId bidirectional_stream_id = + QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT); + QuicStreamFrame data3(bidirectional_stream_id, false, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data3); + EXPECT_TRUE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + session_.ProcessAllPendingStreams(); + // Both bidirectional and read-unidirectional streams are unbuffered. + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_FALSE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(2, session_.num_incoming_streams_created()); +} + TEST_P(QuicSessionTestServer, RstPendingStreams) { if (!VersionUsesHttp3(transport_version())) { return; } session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( transport_version(), Perspective::IS_CLIENT); @@ -1877,6 +1917,27 @@ TEST_P(QuicSessionTestServer, RstPendingStreams) { EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); EXPECT_EQ(0, session_.num_incoming_streams_created()); EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + session_.ProcessAllPendingStreams(); + // Bidirectional stream is buffered. + QuicStreamId bidirectional_stream_id = + QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT); + QuicStreamFrame data3(bidirectional_stream_id, false, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data3); + EXPECT_TRUE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + // Bidirectional pending stream is removed after RST_STREAM is received. + QuicRstStreamFrame rst2(kInvalidControlFrameId, bidirectional_stream_id, + QUIC_ERROR_PROCESSING_STREAM, 12); + session_.OnRstStream(rst2); + EXPECT_FALSE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); } TEST_P(QuicSessionTestServer, OnFinPendingStreams) { @@ -1884,6 +1945,7 @@ TEST_P(QuicSessionTestServer, OnFinPendingStreams) { return; } session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(true); QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( transport_version(), Perspective::IS_CLIENT); @@ -1893,9 +1955,30 @@ TEST_P(QuicSessionTestServer, OnFinPendingStreams) { EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); EXPECT_EQ(0, session_.num_incoming_streams_created()); EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(&session_)); + + session_.set_process_pending_stream_immediately(false); + // Bidirectional pending stream remains after Fin is received. + // Bidirectional stream is buffered. + QuicStreamId bidirectional_stream_id = + QuicUtils::GetFirstBidirectionalStreamId(transport_version(), + Perspective::IS_CLIENT); + QuicStreamFrame data2(bidirectional_stream_id, true, 0, + absl::string_view("HT")); + session_.OnStreamFrame(data2); + EXPECT_TRUE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + session_.ProcessAllPendingStreams(); + EXPECT_FALSE( + QuicSessionPeer::GetPendingStream(&session_, bidirectional_stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + QuicStream* bidirectional_stream = + QuicSessionPeer::GetStream(&session_, bidirectional_stream_id); + EXPECT_TRUE(bidirectional_stream->fin_received()); } -TEST_P(QuicSessionTestServer, PendingStreamOnWindowUpdate) { +TEST_P(QuicSessionTestServer, UnidirectionalPendingStreamOnWindowUpdate) { if (!VersionUsesHttp3(transport_version())) { return; } @@ -1917,6 +2000,80 @@ TEST_P(QuicSessionTestServer, PendingStreamOnWindowUpdate) { session_.OnWindowUpdateFrame(window_update_frame); } +TEST_P(QuicSessionTestServer, BidirectionalPendingStreamOnWindowUpdate) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data); + QuicWindowUpdateFrame window_update_frame(kInvalidControlFrameId, stream_id, + kDefaultFlowControlSendWindow * 2); + session_.OnWindowUpdateFrame(window_update_frame); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + session_.ProcessAllPendingStreams(); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + QuicStream* bidirectional_stream = + QuicSessionPeer::GetStream(&session_, stream_id); + QuicByteCount send_window = + QuicStreamPeer::SendWindowSize(bidirectional_stream); + EXPECT_EQ(send_window, kDefaultFlowControlSendWindow * 2); +} + +TEST_P(QuicSessionTestServer, UnidirectionalPendingStreamOnStopSending) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + QuicStreamId stream_id = QuicUtils::GetFirstUnidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data1(stream_id, true, 10, absl::string_view("HT")); + session_.OnStreamFrame(data1); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + QuicStopSendingFrame stop_sending_frame(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED); + EXPECT_CALL( + *connection_, + CloseConnection(QUIC_INVALID_STREAM_ID, + "Received STOP_SENDING for a read-only stream", _)); + session_.OnStopSendingFrame(stop_sending_frame); +} + +TEST_P(QuicSessionTestServer, BidirectionalPendingStreamOnStopSending) { + if (!VersionUsesHttp3(transport_version())) { + return; + } + + session_.set_uses_pending_streams(true); + session_.set_process_pending_stream_immediately(false); + QuicStreamId stream_id = QuicUtils::GetFirstBidirectionalStreamId( + transport_version(), Perspective::IS_CLIENT); + QuicStreamFrame data(stream_id, true, 0, absl::string_view("HT")); + session_.OnStreamFrame(data); + QuicStopSendingFrame stop_sending_frame(kInvalidControlFrameId, stream_id, + QUIC_STREAM_CANCELLED); + session_.OnStopSendingFrame(stop_sending_frame); + EXPECT_TRUE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(0, session_.num_incoming_streams_created()); + + EXPECT_CALL(*connection_, OnStreamReset(stream_id, _)); + session_.ProcessAllPendingStreams(); + EXPECT_FALSE(QuicSessionPeer::GetPendingStream(&session_, stream_id)); + EXPECT_EQ(1, session_.num_incoming_streams_created()); + QuicStream* bidirectional_stream = + QuicSessionPeer::GetStream(&session_, stream_id); + EXPECT_TRUE(bidirectional_stream->write_side_closed()); +} + TEST_P(QuicSessionTestServer, DrainingStreamsDoNotCountAsOpened) { // Verify that a draining stream (which has received a FIN but not consumed // it) does not count against the open quota (because it is closed from the @@ -2451,36 +2608,26 @@ TEST_P(QuicSessionTestServer, RetransmitLostDataCausesConnectionClose) { TEST_P(QuicSessionTestServer, SendMessage) { // Cannot send message when encryption is not established. EXPECT_FALSE(session_.OneRttKeysAvailable()); - quic::QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); EXPECT_EQ(MessageResult(MESSAGE_STATUS_ENCRYPTION_NOT_ESTABLISHED, 0), - session_.SendMessage( - MakeSpan(connection_->helper()->GetStreamSendBufferAllocator(), - "", &storage))); + session_.SendMessage(MemSliceFromString(""))); CompleteHandshake(); EXPECT_TRUE(session_.OneRttKeysAvailable()); - absl::string_view message; EXPECT_CALL(*connection_, SendMessage(1, _, false)) .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, 1), - session_.SendMessage( - MakeSpan(connection_->helper()->GetStreamSendBufferAllocator(), - message, &storage))); + session_.SendMessage(MemSliceFromString(""))); // Verify message_id increases. EXPECT_CALL(*connection_, SendMessage(2, _, false)) .WillOnce(Return(MESSAGE_STATUS_TOO_LARGE)); EXPECT_EQ(MessageResult(MESSAGE_STATUS_TOO_LARGE, 0), - session_.SendMessage( - MakeSpan(connection_->helper()->GetStreamSendBufferAllocator(), - message, &storage))); + session_.SendMessage(MemSliceFromString(""))); // Verify unsent message does not consume a message_id. EXPECT_CALL(*connection_, SendMessage(2, _, false)) .WillOnce(Return(MESSAGE_STATUS_SUCCESS)); EXPECT_EQ(MessageResult(MESSAGE_STATUS_SUCCESS, 2), - session_.SendMessage( - MakeSpan(connection_->helper()->GetStreamSendBufferAllocator(), - message, &storage))); + session_.SendMessage(MemSliceFromString(""))); QuicMessageFrame frame(1); QuicMessageFrame frame2(2); @@ -2507,7 +2654,7 @@ TEST_P(QuicSessionTestServer, LocallyResetZombieStreams) { EXPECT_TRUE(stream2->IsWaitingForAcks()); // Verify stream2 is a zombie streams. auto& stream_map = QuicSessionPeer::stream_map(&session_); - ASSERT_TRUE(QuicContainsKey(stream_map, stream2->id())); + ASSERT_TRUE(stream_map.contains(stream2->id())); auto* stream = stream_map.find(stream2->id())->second.get(); EXPECT_TRUE(stream->IsZombie()); @@ -2556,7 +2703,7 @@ TEST_P(QuicSessionTestServer, WriteUnidirectionalStream) { stream4->WriteOrBufferData(body, false, nullptr); stream4->WriteOrBufferData(body, true, nullptr); auto& stream_map = QuicSessionPeer::stream_map(&session_); - ASSERT_TRUE(QuicContainsKey(stream_map, stream4->id())); + ASSERT_TRUE(stream_map.contains(stream4->id())); auto* stream = stream_map.find(stream4->id())->second.get(); EXPECT_TRUE(stream->IsZombie()); } @@ -2634,12 +2781,11 @@ TEST_P(QuicSessionTestServer, WriteMemSlicesOnReadUnidirectionalStream) { CloseConnection( QUIC_TRY_TO_WRITE_DATA_ON_READ_UNIDIRECTIONAL_STREAM, _, _)) .Times(1); - char data[1024]; - std::vector> buffers; - buffers.push_back(std::make_pair(data, ABSL_ARRAYSIZE(data))); - buffers.push_back(std::make_pair(data, ABSL_ARRAYSIZE(data))); - QuicTestMemSliceVector vector(buffers); - stream4->WriteMemSlices(vector.span(), false); + std::string data(1024, 'a'); + std::vector buffers; + buffers.push_back(MemSliceFromString(data)); + buffers.push_back(MemSliceFromString(data)); + stream4->WriteMemSlices(absl::MakeSpan(buffers), false); } // Test code that tests that an incoming stream frame with a new (not previously diff --git a/gquiche/quic/core/quic_simple_buffer_allocator.h b/gquiche/quic/core/quic_simple_buffer_allocator.h index 71aea3cd..b7639cef 100644 --- a/gquiche/quic/core/quic_simple_buffer_allocator.h +++ b/gquiche/quic/core/quic_simple_buffer_allocator.h @@ -12,6 +12,11 @@ namespace quic { class QUIC_EXPORT_PRIVATE SimpleBufferAllocator : public QuicBufferAllocator { public: + static SimpleBufferAllocator* Get() { + static SimpleBufferAllocator* singleton = new SimpleBufferAllocator(); + return singleton; + } + char* New(size_t size) override; char* New(size_t size, bool flag_enable) override; void Delete(char* buffer) override; diff --git a/gquiche/quic/core/quic_simple_buffer_allocator_test.cc b/gquiche/quic/core/quic_simple_buffer_allocator_test.cc index 0158f355..a4b12086 100644 --- a/gquiche/quic/core/quic_simple_buffer_allocator_test.cc +++ b/gquiche/quic/core/quic_simple_buffer_allocator_test.cc @@ -24,5 +24,43 @@ TEST_F(SimpleBufferAllocatorTest, DeleteNull) { alloc.Delete(nullptr); } +TEST_F(SimpleBufferAllocatorTest, MoveBuffersConstructor) { + SimpleBufferAllocator alloc; + QuicBuffer buffer1(&alloc, 16); + + EXPECT_NE(buffer1.data(), nullptr); + EXPECT_EQ(buffer1.size(), 16u); + + QuicBuffer buffer2(std::move(buffer1)); + EXPECT_EQ(buffer1.data(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(buffer1.size(), 0u); + EXPECT_NE(buffer2.data(), nullptr); + EXPECT_EQ(buffer2.size(), 16u); +} + +TEST_F(SimpleBufferAllocatorTest, MoveBuffersAssignment) { + SimpleBufferAllocator alloc; + QuicBuffer buffer1(&alloc, 16); + QuicBuffer buffer2; + + EXPECT_NE(buffer1.data(), nullptr); + EXPECT_EQ(buffer1.size(), 16u); + EXPECT_EQ(buffer2.data(), nullptr); + EXPECT_EQ(buffer2.size(), 0u); + + buffer2 = std::move(buffer1); + EXPECT_EQ(buffer1.data(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(buffer1.size(), 0u); + EXPECT_NE(buffer2.data(), nullptr); + EXPECT_EQ(buffer2.size(), 16u); +} + +TEST_F(SimpleBufferAllocatorTest, CopyBuffer) { + SimpleBufferAllocator alloc; + const absl::string_view original = "Test string"; + QuicBuffer copy = QuicBuffer::Copy(&alloc, original); + EXPECT_EQ(copy.AsStringView(), original); +} + } // namespace } // namespace quic diff --git a/gquiche/quic/core/quic_stream.cc b/gquiche/quic/core/quic_stream.cc index 8eb9772b..3867cabc 100644 --- a/gquiche/quic/core/quic_stream.cc +++ b/gquiche/quic/core/quic_stream.cc @@ -15,10 +15,13 @@ #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" +#include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" +#include "gquiche/common/platform/api/quiche_logging.h" using spdy::SpdyPriority; @@ -117,9 +120,12 @@ PendingStream::PendingStream(QuicStreamId id, QuicSession* session) stream_delegate_(session), stream_bytes_read_(0), fin_received_(false), + is_bidirectional_(QuicUtils::GetStreamType(id, session->perspective(), + /*peer_initiated = */ true, + session->version()) == + BIDIRECTIONAL), connection_flow_controller_(session->flow_controller()), - flow_controller_(session, - id, + flow_controller_(session, id, /*is_connection_flow_controller*/ false, GetReceivedFlowControlWindow(session, id), GetInitialStreamFlowControlWindowToSend(session, id), @@ -133,9 +139,7 @@ void PendingStream::OnDataAvailable() { // QuicSession::ProcessPendingStream() can read it. } -void PendingStream::OnFinRead() { - QUICHE_DCHECK(sequencer_.IsClosed()); -} +void PendingStream::OnFinRead() { QUICHE_DCHECK(sequencer_.IsClosed()); } void PendingStream::AddBytesConsumed(QuicByteCount bytes) { // It will be called when the metadata of the stream is consumed. @@ -143,7 +147,7 @@ void PendingStream::AddBytesConsumed(QuicByteCount bytes) { connection_flow_controller_->AddBytesConsumed(bytes); } -void PendingStream::Reset(QuicRstStreamErrorCode /*error*/) { +void PendingStream::ResetWithError(QuicResetStreamError /*error*/) { // Currently PendingStream is only read-unidirectional. It shouldn't send // Reset. QUIC_NOTREACHED(); @@ -160,13 +164,9 @@ void PendingStream::OnUnrecoverableError(QuicErrorCode error, stream_delegate_->OnStreamError(error, ietf_error, details); } -QuicStreamId PendingStream::id() const { - return id_; -} +QuicStreamId PendingStream::id() const { return id_; } -ParsedQuicVersion PendingStream::version() const { - return version_; -} +ParsedQuicVersion PendingStream::version() const { return version_; } void PendingStream::OnStreamFrame(const QuicStreamFrame& frame) { QUICHE_DCHECK_EQ(frame.stream_id, id_); @@ -251,6 +251,11 @@ void PendingStream::OnRstStreamFrame(const QuicRstStreamFrame& frame) { } } +void PendingStream::OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame) { + QUICHE_DCHECK(is_bidirectional_); + flow_controller_.UpdateSendWindowOffset(frame.max_data); +} + bool PendingStream::MaybeIncreaseHighestReceivedOffset( QuicStreamOffset new_offset) { uint64_t increment = @@ -267,6 +272,13 @@ bool PendingStream::MaybeIncreaseHighestReceivedOffset( return true; } +void PendingStream::OnStopSending( + QuicResetStreamError stop_sending_error_code) { + if (!stop_sending_error_code_) { + stop_sending_error_code_ = stop_sending_error_code; + } +} + void PendingStream::MarkConsumed(QuicByteCount num_bytes) { sequencer_.MarkConsumed(num_bytes); } @@ -276,19 +288,17 @@ void PendingStream::StopReading() { sequencer_.StopReading(); } -QuicStream::QuicStream(PendingStream* pending, - QuicSession* session, - StreamType type, +QuicStream::QuicStream(PendingStream* pending, QuicSession* session, bool is_static) - : QuicStream(pending->id_, - session, - std::move(pending->sequencer_), + : QuicStream(pending->id_, session, std::move(pending->sequencer_), is_static, - type, - pending->stream_bytes_read_, - pending->fin_received_, + QuicUtils::GetStreamType(pending->id_, session->perspective(), + /*peer_initiated = */ true, + session->version()), + pending->stream_bytes_read_, pending->fin_received_, std::move(pending->flow_controller_), pending->connection_flow_controller_) { + QUICHE_DCHECK(session->version().HasIetfQuicFrames()); sequencer_.set_stream(this); } @@ -316,26 +326,15 @@ absl::optional FlowController(QuicStreamId id, } // namespace -QuicStream::QuicStream(QuicStreamId id, - QuicSession* session, - bool is_static, +QuicStream::QuicStream(QuicStreamId id, QuicSession* session, bool is_static, StreamType type) - : QuicStream(id, - session, - QuicStreamSequencer(this), - is_static, - type, - 0, - false, - FlowController(id, session, type), + : QuicStream(id, session, QuicStreamSequencer(this), is_static, type, 0, + false, FlowController(id, session, type), session->flow_controller()) {} -QuicStream::QuicStream(QuicStreamId id, - QuicSession* session, - QuicStreamSequencer sequencer, - bool is_static, - StreamType type, - uint64_t stream_bytes_read, +QuicStream::QuicStream(QuicStreamId id, QuicSession* session, + QuicStreamSequencer sequencer, bool is_static, + StreamType type, uint64_t stream_bytes_read, bool fin_received, absl::optional flow_controller, QuicFlowController* connection_flow_controller) @@ -345,10 +344,11 @@ QuicStream::QuicStream(QuicStreamId id, stream_delegate_(session), precedence_(CalculateDefaultPriority(session)), stream_bytes_read_(stream_bytes_read), - stream_error_(QUIC_STREAM_NO_ERROR), + stream_error_(QuicResetStreamError::NoError()), connection_error_(QUIC_NO_ERROR), read_side_closed_(false), write_side_closed_(false), + write_side_data_recvd_state_notified_(false), fin_buffered_(false), fin_sent_(false), fin_outstanding_(false), @@ -370,8 +370,7 @@ QuicStream::QuicStream(QuicStreamId id, was_draining_(false), type_(VersionHasIetfQuicFrames(session->transport_version()) && type != CRYPTO - ? QuicUtils::GetStreamType(id_, - session->perspective(), + ? QuicUtils::GetStreamType(id_, session->perspective(), session->IsIncomingStream(id_), session->version()) : type), @@ -488,7 +487,7 @@ void QuicStream::OnStreamFrame(const QuicStreamFrame& frame) { sequencer_.OnStreamFrame(frame); } -bool QuicStream::OnStopSending(QuicRstStreamErrorCode code) { +bool QuicStream::OnStopSending(QuicResetStreamError error) { // Do not reset the stream if all data has been sent and acknowledged. if (write_side_closed() && !IsWaitingForAcks()) { QUIC_DVLOG(1) << ENDPOINT @@ -506,8 +505,14 @@ bool QuicStream::OnStopSending(QuicRstStreamErrorCode code) { return false; } - stream_error_ = code; - MaybeSendRstStream(code); + stream_error_ = error; + if (GetQuicReloadableFlag(quic_match_ietf_reset_code)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_match_ietf_reset_code); + MaybeSendRstStream(error); + } else { + MaybeSendRstStream( + QuicResetStreamError::FromInternal(error.internal_code())); + } return true; } @@ -552,7 +557,7 @@ void QuicStream::OnStreamReset(const QuicRstStreamFrame& frame) { return; } - stream_error_ = frame.error_code; + stream_error_ = frame.error(); // Google QUIC closes both sides of the stream in response to a // RESET_STREAM, IETF QUIC closes only the read side. if (!VersionHasIetfQuicFrames(transport_version())) { @@ -567,7 +572,8 @@ void QuicStream::OnConnectionClosed(QuicErrorCode error, return; } if (error != QUIC_NO_ERROR) { - stream_error_ = QUIC_STREAM_CONNECTION_ERROR; + stream_error_ = + QuicResetStreamError::FromInternal(QUIC_STREAM_CONNECTION_ERROR); connection_error_ = error; } @@ -592,6 +598,10 @@ void QuicStream::SetFinSent() { } void QuicStream::Reset(QuicRstStreamErrorCode error) { + ResetWithError(QuicResetStreamError::FromInternal(error)); +} + +void QuicStream::ResetWithError(QuicResetStreamError error) { stream_error_ = error; QuicConnection::ScopedPacketFlusher flusher(session()->connection()); MaybeSendStopSending(error); @@ -599,7 +609,24 @@ void QuicStream::Reset(QuicRstStreamErrorCode error) { if (read_side_closed_ && write_side_closed_ && !IsWaitingForAcks()) { session()->MaybeCloseZombieStream(id_); - return; + } +} + +void QuicStream::ResetWriteSide(QuicResetStreamError error) { + stream_error_ = error; + MaybeSendRstStream(error); + + if (read_side_closed_ && write_side_closed_ && !IsWaitingForAcks()) { + session()->MaybeCloseZombieStream(id_); + } +} + +void QuicStream::SendStopSending(QuicResetStreamError error) { + stream_error_ = error; + MaybeSendStopSending(error); + + if (read_side_closed_ && write_side_closed_ && !IsWaitingForAcks()) { + session()->MaybeCloseZombieStream(id_); } } @@ -627,26 +654,20 @@ void QuicStream::SetPriority(const spdy::SpdyStreamPrecedence& precedence) { } void QuicStream::WriteOrBufferData( - absl::string_view data, - bool fin, + absl::string_view data, bool fin, QuicReferenceCountedPointer ack_listener) { - if (session()->use_write_or_buffer_data_at_level()) { - QUIC_BUG_IF(quic_bug_12570_4, - QuicUtils::IsCryptoStreamId(transport_version(), id_)) - << ENDPOINT - << "WriteOrBufferData is used to send application data, use " - "WriteOrBufferDataAtLevel to send crypto data."; - return WriteOrBufferDataAtLevel( - data, fin, session()->GetEncryptionLevelToSendApplicationData(), - ack_listener); - } - return WriteOrBufferDataInner(data, fin, absl::nullopt, ack_listener); + QUIC_BUG_IF(quic_bug_12570_4, + QuicUtils::IsCryptoStreamId(transport_version(), id_)) + << ENDPOINT + << "WriteOrBufferData is used to send application data, use " + "WriteOrBufferDataAtLevel to send crypto data."; + return WriteOrBufferDataAtLevel( + data, fin, session()->GetEncryptionLevelToSendApplicationData(), + ack_listener); } -void QuicStream::WriteOrBufferDataInner( - absl::string_view data, - bool fin, - absl::optional level, +void QuicStream::WriteOrBufferDataAtLevel( + absl::string_view data, bool fin, EncryptionLevel level, QuicReferenceCountedPointer ack_listener) { if (data.empty() && !fin) { QUIC_BUG(quic_bug_10586_2) << "data.empty() && !fin"; @@ -691,16 +712,6 @@ void QuicStream::WriteOrBufferDataInner( } } -void QuicStream::WriteOrBufferDataAtLevel( - absl::string_view data, - bool fin, - EncryptionLevel level, - QuicReferenceCountedPointer ack_listener) { - QUICHE_DCHECK(session()->use_write_or_buffer_data_at_level()); - QUIC_RELOADABLE_FLAG_COUNT(quic_use_write_or_buffer_data_at_level); - return WriteOrBufferDataInner(data, fin, level, ack_listener); -} - void QuicStream::OnCanWrite() { if (HasDeadlinePassed()) { OnDeadlinePassed(); @@ -720,11 +731,7 @@ void QuicStream::OnCanWrite() { return; } if (HasBufferedData() || (fin_buffered_ && !fin_sent_)) { - absl::optional send_level = absl::nullopt; - if (session()->use_write_or_buffer_data_at_level()) { - send_level = session()->GetEncryptionLevelToSendApplicationData(); - } - WriteBufferedData(send_level); + WriteBufferedData(session()->GetEncryptionLevelToSendApplicationData()); } if (!fin_buffered_ && !fin_sent_ && CanWriteNewData()) { // Notify upper layer to write new data when buffered data size is below @@ -758,7 +765,12 @@ void QuicStream::MaybeSendBlocked() { } } -QuicConsumedData QuicStream::WriteMemSlices(QuicMemSliceSpan span, bool fin) { +QuicConsumedData QuicStream::WriteMemSlice(QuicMemSlice span, bool fin) { + return WriteMemSlices(absl::MakeSpan(&span, 1), fin); +} + +QuicConsumedData QuicStream::WriteMemSlices(absl::Span span, + bool fin) { QuicConsumedData consumed_data(0, false); if (span.empty() && !fin) { QUIC_BUG(quic_bug_10586_6) << "span.empty() && !fin"; @@ -802,11 +814,7 @@ QuicConsumedData QuicStream::WriteMemSlices(QuicMemSliceSpan span, bool fin) { if (!had_buffered_data && (HasBufferedData() || fin_buffered_)) { // Write data if there is no buffered data before. - absl::optional send_level = absl::nullopt; - if (session()->use_write_or_buffer_data_at_level()) { - send_level = session()->GetEncryptionLevelToSendApplicationData(); - } - WriteBufferedData(send_level); + WriteBufferedData(session()->GetEncryptionLevelToSendApplicationData()); } return consumed_data; @@ -853,12 +861,12 @@ void QuicStream::CloseWriteSide() { } } -void QuicStream::MaybeSendStopSending(QuicRstStreamErrorCode error) { +void QuicStream::MaybeSendStopSending(QuicResetStreamError error) { if (stop_sending_sent_) { return; } - if (!session()->version().UsesHttp3() && error != QUIC_STREAM_NO_ERROR) { + if (!session()->version().UsesHttp3() && !error.ok()) { // In gQUIC, RST with error closes both read and write side. return; } @@ -866,21 +874,21 @@ void QuicStream::MaybeSendStopSending(QuicRstStreamErrorCode error) { if (session()->version().UsesHttp3()) { session()->MaybeSendStopSendingFrame(id(), error); } else { - QUICHE_DCHECK_EQ(QUIC_STREAM_NO_ERROR, error); - session()->MaybeSendRstStreamFrame(id(), QUIC_STREAM_NO_ERROR, + QUICHE_DCHECK_EQ(QUIC_STREAM_NO_ERROR, error.internal_code()); + session()->MaybeSendRstStreamFrame(id(), QuicResetStreamError::NoError(), stream_bytes_written()); } stop_sending_sent_ = true; CloseReadSide(); } -void QuicStream::MaybeSendRstStream(QuicRstStreamErrorCode error) { +void QuicStream::MaybeSendRstStream(QuicResetStreamError error) { if (rst_sent_) { return; } if (!session()->version().UsesHttp3()) { - QUIC_BUG_IF(quic_bug_12570_5, error == QUIC_STREAM_NO_ERROR); + QUIC_BUG_IF(quic_bug_12570_5, error.ok()); stop_sending_sent_ = true; CloseReadSide(); } @@ -894,9 +902,7 @@ bool QuicStream::HasBufferedData() const { return send_buffer_.stream_offset() > stream_bytes_written(); } -ParsedQuicVersion QuicStream::version() const { - return session_->version(); -} +ParsedQuicVersion QuicStream::version() const { return session_->version(); } QuicTransportVersion QuicStream::transport_version() const { return session_->transport_version(); @@ -1077,8 +1083,7 @@ void QuicStream::AddRandomPaddingAfterFin() { } bool QuicStream::OnStreamFrameAcked(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin_acked, + QuicByteCount data_length, bool fin_acked, QuicTime::Delta /*ack_delay_time*/, QuicTime /*receive_timestamp*/, QuicByteCount* newly_acked_length) { @@ -1102,6 +1107,11 @@ bool QuicStream::OnStreamFrameAcked(QuicStreamOffset offset, fin_outstanding_ = false; fin_lost_ = false; } + if (!IsWaitingForAcks() && write_side_closed_ && + !write_side_data_recvd_state_notified_) { + OnWriteSideInDataRecvdState(); + write_side_data_recvd_state_notified_ = true; + } if (!IsWaitingForAcks() && read_side_closed_ && write_side_closed_) { session_->MaybeCloseZombieStream(id_); } @@ -1118,8 +1128,7 @@ void QuicStream::OnStreamFrameRetransmitted(QuicStreamOffset offset, } void QuicStream::OnStreamFrameLost(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin_lost) { + QuicByteCount data_length, bool fin_lost) { QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " Losting " << "[" << offset << ", " << offset + data_length << "]" << " fin = " << fin_lost; @@ -1132,8 +1141,7 @@ void QuicStream::OnStreamFrameLost(QuicStreamOffset offset, } bool QuicStream::RetransmitStreamData(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin, + QuicByteCount data_length, bool fin, TransmissionType type) { QUICHE_DCHECK(type == PTO_RETRANSMISSION || type == RTO_RETRANSMISSION || type == TLP_RETRANSMISSION || type == PROBING_RETRANSMISSION); @@ -1148,10 +1156,6 @@ bool QuicStream::RetransmitStreamData(QuicStreamOffset offset, if (retransmission.Empty() && !retransmit_fin) { return true; } - absl::optional send_level = absl::nullopt; - if (session()->use_write_or_buffer_data_at_level()) { - send_level = session()->GetEncryptionLevelToSendApplicationData(); - } QuicConsumedData consumed(0, false); for (const auto& interval : retransmission) { QuicStreamOffset retransmission_offset = interval.min(); @@ -1161,7 +1165,8 @@ bool QuicStream::RetransmitStreamData(QuicStreamOffset offset, stream_bytes_written()); consumed = stream_delegate_->WritevData( id_, retransmission_length, retransmission_offset, - can_bundle_fin ? FIN : NO_FIN, type, send_level); + can_bundle_fin ? FIN : NO_FIN, type, + session()->GetEncryptionLevelToSendApplicationData()); QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " is forced to retransmit stream data [" << retransmission_offset << ", " @@ -1182,8 +1187,9 @@ bool QuicStream::RetransmitStreamData(QuicStreamOffset offset, if (retransmit_fin) { QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " retransmits fin only frame."; - consumed = stream_delegate_->WritevData(id_, 0, stream_bytes_written(), FIN, - type, send_level); + consumed = stream_delegate_->WritevData( + id_, 0, stream_bytes_written(), FIN, type, + session()->GetEncryptionLevelToSendApplicationData()); if (!consumed.fin_consumed) { return false; } @@ -1192,7 +1198,7 @@ bool QuicStream::RetransmitStreamData(QuicStreamOffset offset, } bool QuicStream::IsWaitingForAcks() const { - return (!rst_sent_ || stream_error_ == QUIC_STREAM_NO_ERROR) && + return (!rst_sent_ || stream_error_.ok()) && (send_buffer_.stream_bytes_outstanding() || fin_outstanding_); } @@ -1209,7 +1215,7 @@ bool QuicStream::WriteStreamData(QuicStreamOffset offset, return send_buffer_.WriteStreamData(offset, data_length, writer); } -void QuicStream::WriteBufferedData(absl::optional level) { +void QuicStream::WriteBufferedData(EncryptionLevel level) { QUICHE_DCHECK(!write_side_closed_ && (HasBufferedData() || fin_buffered_)); if (session_->ShouldYield(id())) { @@ -1333,15 +1339,12 @@ void QuicStream::OnStreamDataConsumed(QuicByteCount bytes_consumed) { void QuicStream::WritePendingRetransmission() { while (HasPendingRetransmission()) { QuicConsumedData consumed(0, false); - absl::optional send_level = absl::nullopt; - if (session()->use_write_or_buffer_data_at_level()) { - send_level = session()->GetEncryptionLevelToSendApplicationData(); - } if (!send_buffer_.HasPendingRetransmission()) { QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " retransmits fin only frame."; consumed = stream_delegate_->WritevData( - id_, 0, stream_bytes_written(), FIN, LOSS_RETRANSMISSION, send_level); + id_, 0, stream_bytes_written(), FIN, LOSS_RETRANSMISSION, + session()->GetEncryptionLevelToSendApplicationData()); fin_lost_ = !consumed.fin_consumed; if (fin_lost_) { // Connection is write blocked. @@ -1356,7 +1359,8 @@ void QuicStream::WritePendingRetransmission() { (pending.offset + pending.length == stream_bytes_written()); consumed = stream_delegate_->WritevData( id_, pending.length, pending.offset, can_bundle_fin ? FIN : NO_FIN, - LOSS_RETRANSMISSION, send_level); + LOSS_RETRANSMISSION, + session()->GetEncryptionLevelToSendApplicationData()); QUIC_DVLOG(1) << ENDPOINT << "stream " << id_ << " tries to retransmit stream data [" << pending.offset << ", " << pending.offset + pending.length @@ -1401,9 +1405,7 @@ bool QuicStream::HasDeadlinePassed() const { return true; } -void QuicStream::OnDeadlinePassed() { - Reset(QUIC_STREAM_TTL_EXPIRED); -} +void QuicStream::OnDeadlinePassed() { Reset(QUIC_STREAM_TTL_EXPIRED); } bool QuicStream::IsFlowControlBlocked() const { if (!flow_controller_.has_value()) { diff --git a/gquiche/quic/core/quic_stream.h b/gquiche/quic/core/quic_stream.h index 1af0bf0f..36669f97 100644 --- a/gquiche/quic/core/quic_stream.h +++ b/gquiche/quic/core/quic_stream.h @@ -24,6 +24,9 @@ #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "absl/types/span.h" +#include "gquiche/quic/core/frames/quic_rst_stream_frame.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_flow_controller.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_stream_send_buffer.h" @@ -32,7 +35,7 @@ #include "gquiche/quic/core/session_notifier_interface.h" #include "gquiche/quic/core/stream_delegate_interface.h" #include "gquiche/quic/platform/api/quic_export.h" -#include "gquiche/quic/platform/api/quic_mem_slice_span.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" #include "gquiche/quic/platform/api/quic_reference_counted.h" #include "gquiche/spdy/core/spdy_protocol.h" @@ -58,7 +61,7 @@ class QUIC_EXPORT_PRIVATE PendingStream void OnDataAvailable() override; void OnFinRead() override; void AddBytesConsumed(QuicByteCount bytes) override; - void Reset(QuicRstStreamErrorCode error) override; + void ResetWithError(QuicResetStreamError error) override; void OnUnrecoverableError(QuicErrorCode error, const std::string& details) override; void OnUnrecoverableError(QuicErrorCode error, @@ -71,10 +74,21 @@ class QUIC_EXPORT_PRIVATE PendingStream // If the data violates flow control, the connection will be closed. void OnStreamFrame(const QuicStreamFrame& frame); + bool is_bidirectional() const { return is_bidirectional_; } + // Stores the final byte offset from |frame|. // If the final offset violates flow control, the connection will be closed. void OnRstStreamFrame(const QuicRstStreamFrame& frame); + void OnWindowUpdateFrame(const QuicWindowUpdateFrame& frame); + + void OnStopSending(QuicResetStreamError stop_sending_error_code); + + // The error code received from QuicStopSendingFrame (if any). + const absl::optional& GetStopSendingErrorCode() const { + return stop_sending_error_code_; + } + // Returns the number of bytes read on this stream. uint64_t stream_bytes_read() { return stream_bytes_read_; } @@ -107,12 +121,17 @@ class QUIC_EXPORT_PRIVATE PendingStream // True if a frame containing a fin has been received. bool fin_received_; + // True if this pending stream is backing a bidirectional stream. + bool is_bidirectional_; + // Connection-level flow controller. Owned by the session. QuicFlowController* connection_flow_controller_; // Stream-level flow controller. QuicFlowController flow_controller_; // Stores the buffered frames. QuicStreamSequencer sequencer_; + // The error code received from QuicStopSendingFrame (if any). + absl::optional stop_sending_error_code_; }; class QUIC_EXPORT_PRIVATE QuicStream @@ -134,14 +153,9 @@ class QUIC_EXPORT_PRIVATE QuicStream // |type| indicates whether the stream is bidirectional, read unidirectional // or write unidirectional. // TODO(fayang): Remove |type| when IETF stream ID numbering fully kicks in. - QuicStream(QuicStreamId id, - QuicSession* session, - bool is_static, + QuicStream(QuicStreamId id, QuicSession* session, bool is_static, StreamType type); - QuicStream(PendingStream* pending, - QuicSession* session, - StreamType type, - bool is_static); + QuicStream(PendingStream* pending, QuicSession* session, bool is_static); QuicStream(const QuicStream&) = delete; QuicStream& operator=(const QuicStream&) = delete; @@ -160,7 +174,16 @@ class QUIC_EXPORT_PRIVATE QuicStream // Called by the subclass or the sequencer to reset the stream from this // end. - void Reset(QuicRstStreamErrorCode error) override; + void ResetWithError(QuicResetStreamError error) override; + // Convenience wrapper for the method above. + // TODO(b/200606367): switch all calls to using QuicResetStreamError + // interface. + void Reset(QuicRstStreamErrorCode error); + + // Reset() sends both RESET_STREAM and STOP_SENDING; the two methods below + // allow to send only one of those. + void ResetWriteSide(QuicResetStreamError error); + void SendStopSending(QuicResetStreamError error); // Called by the subclass or the sequencer to close the entire connection from // this end. @@ -208,7 +231,9 @@ class QUIC_EXPORT_PRIVATE QuicStream // Number of bytes available to read. QuicByteCount ReadableBytes() const; - QuicRstStreamErrorCode stream_error() const { return stream_error_; } + QuicRstStreamErrorCode stream_error() const { + return stream_error_.internal_code(); + } QuicErrorCode connection_error() const { return connection_error_; } bool reading_stopped() const { @@ -290,31 +315,26 @@ class QUIC_EXPORT_PRIVATE QuicStream // session, write_side_closed() becomes true, otherwise fin_buffered_ becomes // true. void WriteOrBufferData( - absl::string_view data, - bool fin, + absl::string_view data, bool fin, QuicReferenceCountedPointer ack_listener); // Sends |data| to connection with specified |level|. void WriteOrBufferDataAtLevel( - absl::string_view data, - bool fin, - EncryptionLevel level, + absl::string_view data, bool fin, EncryptionLevel level, QuicReferenceCountedPointer ack_listener); // Adds random padding after the fin is consumed for this stream. void AddRandomPaddingAfterFin(); // Write |data_length| of data starts at |offset| from send buffer. - bool WriteStreamData(QuicStreamOffset offset, - QuicByteCount data_length, + bool WriteStreamData(QuicStreamOffset offset, QuicByteCount data_length, QuicDataWriter* writer); // Called when data [offset, offset + data_length) is acked. |fin_acked| // indicates whether the fin is acked. Returns true and updates // |newly_acked_length| if any new stream data (including fin) gets acked. virtual bool OnStreamFrameAcked(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin_acked, + QuicByteCount data_length, bool fin_acked, QuicTime::Delta ack_delay_time, QuicTime receive_timestamp, QuicByteCount* newly_acked_length); @@ -328,24 +348,25 @@ class QUIC_EXPORT_PRIVATE QuicStream // Called when data [offset, offset + data_length) is considered as lost. // |fin_lost| indicates whether the fin is considered as lost. virtual void OnStreamFrameLost(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin_lost); + QuicByteCount data_length, bool fin_lost); // Called to retransmit outstanding portion in data [offset, offset + // data_length) and |fin| with Transmission |type|. // Returns true if all data gets retransmitted. virtual bool RetransmitStreamData(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin, + QuicByteCount data_length, bool fin, TransmissionType type); // Sets deadline of this stream to be now + |ttl|, returns true if the setting // succeeds. bool MaybeSetTtl(QuicTime::Delta ttl); - // Same as WritevData except data is provided in reference counted memory so - // that data copy is avoided. - QuicConsumedData WriteMemSlices(QuicMemSliceSpan span, bool fin); + // Commits data into the stream write buffer, and potentially sends it over + // the wire. This method has all-or-nothing semantics: if the write buffer is + // not full, all of the memslices in |span| are moved into it; otherwise, + // nothing happens. + QuicConsumedData WriteMemSlices(absl::Span span, bool fin); + QuicConsumedData WriteMemSlice(QuicMemSlice span, bool fin); // Returns true if any stream data is lost (including fin) and needs to be // retransmitted. @@ -355,14 +376,13 @@ class QUIC_EXPORT_PRIVATE QuicStream // outstanding or fin is outstanding (if |fin| is true). Returns false // otherwise. bool IsStreamFrameOutstanding(QuicStreamOffset offset, - QuicByteCount data_length, - bool fin) const; + QuicByteCount data_length, bool fin) const; StreamType type() const { return type_; } // Handle received StopSending frame. Returns true if the processing finishes // gracefully. - virtual bool OnStopSending(QuicRstStreamErrorCode code); + virtual bool OnStopSending(QuicResetStreamError error); // Returns true if the stream is static. bool is_static() const { return is_static_; } @@ -387,8 +407,7 @@ class QUIC_EXPORT_PRIVATE QuicStream // Called when data of [offset, offset + data_length] is buffered in send // buffer. virtual void OnDataBuffered( - QuicStreamOffset /*offset*/, - QuicByteCount /*data_length*/, + QuicStreamOffset /*offset*/, QuicByteCount /*data_length*/, const QuicReferenceCountedPointer& /*ack_listener*/) {} @@ -425,10 +444,18 @@ class QUIC_EXPORT_PRIVATE QuicStream void SetFinSent(); // Send STOP_SENDING if it hasn't been sent yet. - void MaybeSendStopSending(QuicRstStreamErrorCode error); + void MaybeSendStopSending(QuicResetStreamError error); // Send RESET_STREAM if it hasn't been sent yet. - void MaybeSendRstStream(QuicRstStreamErrorCode error); + void MaybeSendRstStream(QuicResetStreamError error); + + // Convenience warppers for two methods above. + void MaybeSendRstStream(QuicRstStreamErrorCode error) { + MaybeSendRstStream(QuicResetStreamError::FromInternal(error)); + } + void MaybeSendStopSending(QuicRstStreamErrorCode error) { + MaybeSendStopSending(QuicResetStreamError::FromInternal(error)); + } // Close the write side of the socket. Further writes will fail. // Can be called by the subclass or internally. @@ -436,7 +463,7 @@ class QUIC_EXPORT_PRIVATE QuicStream virtual void CloseWriteSide(); void set_rst_received(bool rst_received) { rst_received_ = rst_received; } - void set_stream_error(QuicRstStreamErrorCode error) { stream_error_ = error; } + void set_stream_error(QuicResetStreamError error) { stream_error_ = error; } StreamDelegateInterface* stream_delegate() { return stream_delegate_; } @@ -456,6 +483,11 @@ class QUIC_EXPORT_PRIVATE QuicStream QuicStreamSendBuffer& send_buffer() { return send_buffer_; } + // Called when the write side of the stream is closed, and all of the outgoing + // data has been acknowledged. This corresponds to the "Data Recvd" state of + // RFC 9000. + virtual void OnWriteSideInDataRecvdState() {} + // Return the current flow control send window in bytes. absl::optional GetSendWindow() const; absl::optional GetReceiveWindow() const; @@ -464,13 +496,9 @@ class QUIC_EXPORT_PRIVATE QuicStream friend class test::QuicStreamPeer; friend class QuicStreamUtils; - QuicStream(QuicStreamId id, - QuicSession* session, - QuicStreamSequencer sequencer, - bool is_static, - StreamType type, - uint64_t stream_bytes_read, - bool fin_received, + QuicStream(QuicStreamId id, QuicSession* session, + QuicStreamSequencer sequencer, bool is_static, StreamType type, + uint64_t stream_bytes_read, bool fin_received, absl::optional flow_controller, QuicFlowController* connection_flow_controller); @@ -480,10 +508,8 @@ class QUIC_EXPORT_PRIVATE QuicStream // controller, marks this stream as connection-level write blocked. void MaybeSendBlocked(); - // Write buffered data in send buffer. - // TODO(fayang): Change absl::optional to EncryptionLevel - // when deprecating quic_use_write_or_buffer_data_at_level. - void WriteBufferedData(absl::optional level); + // Write buffered data (in send buffer) at |level|. + void WriteBufferedData(EncryptionLevel level); // Close the read side of the stream. May cause the stream to be closed. void CloseReadSide(); @@ -491,14 +517,6 @@ class QUIC_EXPORT_PRIVATE QuicStream // Called when bytes are sent to the peer. void AddBytesSent(QuicByteCount bytes); - // TODO(fayang): Inline this function when deprecating - // quic_use_write_or_buffer_data_at_level. - void WriteOrBufferDataInner( - absl::string_view data, - bool fin, - absl::optional level, - QuicReferenceCountedPointer ack_listener); - // Returns true if deadline_ has passed. bool HasDeadlinePassed() const; @@ -516,7 +534,7 @@ class QUIC_EXPORT_PRIVATE QuicStream // Stream error code received from a RstStreamFrame or error code sent by the // visitor or sequencer in the RstStreamFrame. - QuicRstStreamErrorCode stream_error_; + QuicResetStreamError stream_error_; // Connection error code due to which the stream was closed. |stream_error_| // is set to |QUIC_STREAM_CONNECTION_ERROR| when this happens and consumers // should check |connection_error_|. @@ -527,6 +545,9 @@ class QUIC_EXPORT_PRIVATE QuicStream // True if the write side is closed, and further writes should fail. bool write_side_closed_; + // True if OnWriteSideInDataRecvdState() has already been called. + bool write_side_data_recvd_state_notified_; + // True if the subclass has written a FIN with WriteOrBufferData, but it was // buffered in queued_data_ rather than being sent to the session. bool fin_buffered_; diff --git a/gquiche/quic/core/quic_stream_id_manager.cc b/gquiche/quic/core/quic_stream_id_manager.cc index 3cdf399d..b0ed0304 100644 --- a/gquiche/quic/core/quic_stream_id_manager.cc +++ b/gquiche/quic/core/quic_stream_id_manager.cc @@ -217,7 +217,7 @@ bool QuicStreamIdManager::IsAvailableStream(QuicStreamId id) const { return largest_peer_created_stream_id_ == QuicUtils::GetInvalidStreamId(version_.transport_version) || id > largest_peer_created_stream_id_ || - QuicContainsKey(available_streams_, id); + available_streams_.contains(id); } QuicStreamId QuicStreamIdManager::GetFirstOutgoingStreamId() const { diff --git a/gquiche/quic/core/quic_stream_send_buffer.cc b/gquiche/quic/core/quic_stream_send_buffer.cc index 3c2dbbe9..c62b0758 100644 --- a/gquiche/quic/core/quic_stream_send_buffer.cc +++ b/gquiche/quic/core/quic_stream_send_buffer.cc @@ -12,6 +12,7 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" namespace quic { @@ -91,9 +92,18 @@ void QuicStreamSendBuffer::SaveMemSlice(QuicMemSlice slice) { stream_offset_ += length; } -QuicByteCount QuicStreamSendBuffer::SaveMemSliceSpan(QuicMemSliceSpan span) { - return span.ConsumeAll( - [&](QuicMemSlice slice) { SaveMemSlice(std::move(slice)); }); +QuicByteCount QuicStreamSendBuffer::SaveMemSliceSpan( + absl::Span span) { + QuicByteCount total = 0; + for (QuicMemSlice& slice : span) { + if (slice.length() == 0) { + // Skip empty slices. + continue; + } + total += slice.length(); + SaveMemSlice(std::move(slice)); + } + return total; } void QuicStreamSendBuffer::OnStreamDataConsumed(size_t bytes_consumed) { diff --git a/gquiche/quic/core/quic_stream_send_buffer.h b/gquiche/quic/core/quic_stream_send_buffer.h index 5c58a0c7..1c02e504 100644 --- a/gquiche/quic/core/quic_stream_send_buffer.h +++ b/gquiche/quic/core/quic_stream_send_buffer.h @@ -5,14 +5,14 @@ #ifndef QUICHE_QUIC_CORE_QUIC_STREAM_SEND_BUFFER_H_ #define QUICHE_QUIC_CORE_QUIC_STREAM_SEND_BUFFER_H_ +#include "absl/types/span.h" #include "gquiche/quic/core/frames/quic_stream_frame.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_interval_deque.h" #include "gquiche/quic/core/quic_interval_set.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_iovec.h" #include "gquiche/quic/platform/api/quic_mem_slice.h" -#include "gquiche/quic/platform/api/quic_mem_slice_span.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -79,7 +79,7 @@ class QUIC_EXPORT_PRIVATE QuicStreamSendBuffer { void SaveMemSlice(QuicMemSlice slice); // Save all slices in |span| to send buffer. Return total bytes saved. - QuicByteCount SaveMemSliceSpan(QuicMemSliceSpan span); + QuicByteCount SaveMemSliceSpan(absl::Span span); // Called when |bytes_consumed| bytes has been consumed by the stream. void OnStreamDataConsumed(size_t bytes_consumed); diff --git a/gquiche/quic/core/quic_stream_send_buffer_test.cc b/gquiche/quic/core/quic_stream_send_buffer_test.cc index fcf2f105..6c1e6cfb 100644 --- a/gquiche/quic/core/quic_stream_send_buffer_test.cc +++ b/gquiche/quic/core/quic_stream_send_buffer_test.cc @@ -13,7 +13,6 @@ #include "gquiche/quic/platform/api/quic_expect_bug.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/quic/platform/api/quic_test_mem_slice_vector.h" #include "gquiche/quic/test_tools/quic_stream_send_buffer_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" @@ -39,12 +38,12 @@ class QuicStreamSendBufferTest : public QuicTest { iov[0] = MakeIovec(absl::string_view(data1)); iov[1] = MakeIovec(absl::string_view(data2)); - QuicUniqueBufferPtr buffer1 = MakeUniqueBuffer(&allocator_, 1024); - memset(buffer1.get(), 'c', 1024); - QuicMemSlice slice1(std::move(buffer1), 1024); - QuicUniqueBufferPtr buffer2 = MakeUniqueBuffer(&allocator_, 768); - memset(buffer2.get(), 'd', 768); - QuicMemSlice slice2(std::move(buffer2), 768); + QuicBuffer buffer1(&allocator_, 1024); + memset(buffer1.data(), 'c', buffer1.size()); + QuicMemSlice slice1(std::move(buffer1)); + QuicBuffer buffer2(&allocator_, 768); + memset(buffer2.data(), 'd', buffer2.size()); + QuicMemSlice slice2(std::move(buffer2)); // The stream offset should be 0 since nothing is written. EXPECT_EQ(0u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); @@ -309,9 +308,9 @@ TEST_F(QuicStreamSendBufferTest, EndOffset) { // Last offset is end offset of last slice. EXPECT_EQ(3840u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); - QuicUniqueBufferPtr buffer = MakeUniqueBuffer(&allocator_, 60); - memset(buffer.get(), 'e', 60); - QuicMemSlice slice(std::move(buffer), 60); + QuicBuffer buffer(&allocator_, 60); + memset(buffer.data(), 'e', buffer.size()); + QuicMemSlice slice(std::move(buffer)); send_buffer_.SaveMemSlice(std::move(slice)); EXPECT_EQ(3840u, QuicStreamSendBufferPeer::EndOffset(&send_buffer_)); @@ -321,14 +320,13 @@ TEST_F(QuicStreamSendBufferTest, SaveMemSliceSpan) { SimpleBufferAllocator allocator; QuicStreamSendBuffer send_buffer(&allocator); - char data[1024]; - std::vector> buffers; + std::string data(1024, 'a'); + std::vector buffers; for (size_t i = 0; i < 10; ++i) { - buffers.push_back(std::make_pair(data, 1024)); + buffers.push_back(MemSliceFromString(data)); } - QuicTestMemSliceVector vector(buffers); - EXPECT_EQ(10 * 1024u, send_buffer.SaveMemSliceSpan(vector.span())); + EXPECT_EQ(10 * 1024u, send_buffer.SaveMemSliceSpan(absl::MakeSpan(buffers))); EXPECT_EQ(10u, send_buffer.size()); } @@ -336,15 +334,13 @@ TEST_F(QuicStreamSendBufferTest, SaveEmptyMemSliceSpan) { SimpleBufferAllocator allocator; QuicStreamSendBuffer send_buffer(&allocator); - char data[1024]; - std::vector> buffers; + std::string data(1024, 'a'); + std::vector buffers; for (size_t i = 0; i < 10; ++i) { - buffers.push_back(std::make_pair(data, 1024)); + buffers.push_back(MemSliceFromString(data)); } - buffers.push_back(std::make_pair(nullptr, 0)); - QuicTestMemSliceVector vector(buffers); - EXPECT_EQ(10 * 1024u, send_buffer.SaveMemSliceSpan(vector.span())); + EXPECT_EQ(10 * 1024u, send_buffer.SaveMemSliceSpan(absl::MakeSpan(buffers))); // Verify the empty slice does not get saved. EXPECT_EQ(10u, send_buffer.size()); } diff --git a/gquiche/quic/core/quic_stream_sequencer.cc b/gquiche/quic/core/quic_stream_sequencer.cc index 2a2fb130..f8614ea9 100644 --- a/gquiche/quic/core/quic_stream_sequencer.cc +++ b/gquiche/quic/core/quic_stream_sequencer.cc @@ -56,19 +56,23 @@ void QuicStreamSequencer::OnStreamFrame(const QuicStreamFrame& frame) { (!CloseStreamAtOffset(frame.offset + data_len) || data_len == 0)) { return; } - if (GetQuicReloadableFlag(quic_accept_empty_stream_frame_with_no_fin)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_accept_empty_stream_frame_with_no_fin); - if (stream_->version().HasIetfQuicFrames() && data_len == 0) { - QUICHE_DCHECK(!frame.fin); - // Ignore empty frame with no fin. - return; - } + if (stream_->version().HasIetfQuicFrames() && data_len == 0) { + QUICHE_DCHECK(!frame.fin); + // Ignore empty frame with no fin. + return; } OnFrameData(byte_offset, data_len, frame.data_buffer); } void QuicStreamSequencer::OnCryptoFrame(const QuicCryptoFrame& frame) { ++num_frames_received_; + if (GetQuicReloadableFlag(quic_accept_empty_crypto_frame)) { + QUIC_RELOADABLE_FLAG_COUNT(quic_accept_empty_crypto_frame); + if (frame.data_length == 0) { + // Ignore empty crypto frame. + return; + } + } OnFrameData(frame.offset, frame.data_length, frame.data_buffer); } @@ -239,7 +243,8 @@ void QuicStreamSequencer::MarkConsumed(size_t num_bytes_consumed) { << "Invalid argument to MarkConsumed." << " expect to consume: " << num_bytes_consumed << ", but not enough bytes available. " << DebugString(); - stream_->Reset(QUIC_ERROR_PROCESSING_STREAM); + stream_->ResetWithError( + QuicResetStreamError::FromInternal(QUIC_ERROR_PROCESSING_STREAM)); return; } stream_->AddBytesConsumed(num_bytes_consumed); diff --git a/gquiche/quic/core/quic_stream_sequencer.h b/gquiche/quic/core/quic_stream_sequencer.h index e50009a9..6aab90d9 100644 --- a/gquiche/quic/core/quic_stream_sequencer.h +++ b/gquiche/quic/core/quic_stream_sequencer.h @@ -37,7 +37,7 @@ class QUIC_EXPORT_PRIVATE QuicStreamSequencer { virtual void AddBytesConsumed(QuicByteCount bytes) = 0; // Called when an error has occurred which should result in the stream // being reset. - virtual void Reset(QuicRstStreamErrorCode error) = 0; + virtual void ResetWithError(QuicResetStreamError error) = 0; // Called when an error has occurred which should result in the connection // being closed. virtual void OnUnrecoverableError(QuicErrorCode error, diff --git a/gquiche/quic/core/quic_stream_sequencer_buffer.cc b/gquiche/quic/core/quic_stream_sequencer_buffer.cc index 43d0030f..cdb73bb2 100644 --- a/gquiche/quic/core/quic_stream_sequencer_buffer.cc +++ b/gquiche/quic/core/quic_stream_sequencer_buffer.cc @@ -46,9 +46,7 @@ QuicStreamSequencerBuffer::QuicStreamSequencerBuffer(size_t max_capacity_bytes) current_blocks_count_(0u), total_bytes_read_(0), blocks_(nullptr) { - if (allocate_blocks_on_demand_) { - QUICHE_DCHECK_GE(max_blocks_count_, kInitialBlockCount); - } + QUICHE_DCHECK_GE(max_blocks_count_, kInitialBlockCount); Clear(); } @@ -58,9 +56,7 @@ QuicStreamSequencerBuffer::~QuicStreamSequencerBuffer() { void QuicStreamSequencerBuffer::Clear() { if (blocks_ != nullptr) { - size_t blocks_to_clear = - allocate_blocks_on_demand_ ? current_blocks_count_ : max_blocks_count_; - for (size_t i = 0; i < blocks_to_clear; ++i) { + for (size_t i = 0; i < current_blocks_count_; ++i) { if (blocks_[i] != nullptr) { RetireBlock(i); } @@ -129,9 +125,7 @@ QuicErrorCode QuicStreamSequencerBuffer::OnStreamData( *error_details = "Received data beyond available range."; return QUIC_INTERNAL_ERROR; } - if (allocate_blocks_on_demand_) { - QUIC_RELOADABLE_FLAG_COUNT( - quic_allocate_stream_sequencer_buffer_blocks_on_demand); + if (!delay_allocation_until_new_data_) { MaybeAddMoreBlocks(starting_offset + size); } @@ -148,6 +142,11 @@ QuicErrorCode QuicStreamSequencerBuffer::OnStreamData( *error_details = "Too many data intervals received for this stream."; return QUIC_TOO_MANY_STREAM_DATA_INTERVALS; } + if (delay_allocation_until_new_data_) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_delay_sequencer_buffer_allocation_until_new_data, 1, 2); + MaybeAddMoreBlocks(starting_offset + size); + } size_t bytes_copy = 0; if (!CopyStreamData(starting_offset, data, &bytes_copy, error_details)) { @@ -171,6 +170,11 @@ QuicErrorCode QuicStreamSequencerBuffer::OnStreamData( *error_details = "Too many data intervals received for this stream."; return QUIC_TOO_MANY_STREAM_DATA_INTERVALS; } + if (delay_allocation_until_new_data_) { + QUIC_RELOADABLE_FLAG_COUNT_N( + quic_delay_sequencer_buffer_allocation_until_new_data, 2, 2); + MaybeAddMoreBlocks(starting_offset + size); + } for (const auto& interval : newly_received) { const QuicStreamOffset copy_offset = interval.min(); const QuicByteCount copy_length = interval.max() - interval.min(); @@ -202,8 +206,7 @@ bool QuicStreamSequencerBuffer::CopyStreamData(QuicStreamOffset offset, while (source_remaining > 0) { const size_t write_block_num = GetBlockIndex(offset); const size_t write_block_offset = GetInBlockOffset(offset); - size_t current_blocks_count = - allocate_blocks_on_demand_ ? current_blocks_count_ : max_blocks_count_; + size_t current_blocks_count = current_blocks_count_; QUICHE_DCHECK_GT(current_blocks_count, write_block_num); size_t block_capacity = GetBlockCapacity(write_block_num); @@ -215,15 +218,6 @@ bool QuicStreamSequencerBuffer::CopyStreamData(QuicStreamOffset offset, bytes_avail = total_bytes_read_ + max_buffer_capacity_bytes_ - offset; } - if (!allocate_blocks_on_demand_) { - if (blocks_ == nullptr) { - blocks_.reset(new BufferBlock*[max_blocks_count_]()); - for (size_t i = 0; i < max_blocks_count_; ++i) { - blocks_[i] = nullptr; - } - } - } - if (write_block_num >= current_blocks_count) { *error_details = absl::StrCat( "QuicStreamSequencerBuffer error: OnStreamData() exceed array bounds." diff --git a/gquiche/quic/core/quic_stream_sequencer_buffer.h b/gquiche/quic/core/quic_stream_sequencer_buffer.h index 8b29a9d9..6a8c02ff 100644 --- a/gquiche/quic/core/quic_stream_sequencer_buffer.h +++ b/gquiche/quic/core/quic_stream_sequencer_buffer.h @@ -229,10 +229,6 @@ class QUIC_EXPORT_PRIVATE QuicStreamSequencerBuffer { // Number of bytes read out of buffer. QuicStreamOffset total_bytes_read_; - // Whether size of blocks_ grows on demand. - bool allocate_blocks_on_demand_ = GetQuicReloadableFlag( - quic_allocate_stream_sequencer_buffer_blocks_on_demand); - // An ordered, variable-length list of blocks, with the length limited // such that the number of blocks never exceeds max_blocks_count_. // Each list entry can hold up to kBlockSizeBytes bytes. @@ -243,6 +239,9 @@ class QUIC_EXPORT_PRIVATE QuicStreamSequencerBuffer { // Currently received data. QuicIntervalSet bytes_received_; + + bool delay_allocation_until_new_data_ = GetQuicReloadableFlag( + quic_delay_sequencer_buffer_allocation_until_new_data); }; } // namespace quic diff --git a/gquiche/quic/core/quic_stream_sequencer_buffer_test.cc b/gquiche/quic/core/quic_stream_sequencer_buffer_test.cc index 68f321a1..0f385201 100644 --- a/gquiche/quic/core/quic_stream_sequencer_buffer_test.cc +++ b/gquiche/quic/core/quic_stream_sequencer_buffer_test.cc @@ -1091,8 +1091,6 @@ TEST_F(QuicStreamSequencerBufferRandomIOTest, RandomWriteAndConsumeInPlace) { } TEST_F(QuicStreamSequencerBufferTest, GrowBlockSizeOnDemand) { - SetQuicReloadableFlag(quic_allocate_stream_sequencer_buffer_blocks_on_demand, - true); max_capacity_bytes_ = 1024 * kBlockSizeBytes; std::string source_of_one_block(kBlockSizeBytes, 'a'); Initialize(); diff --git a/gquiche/quic/core/quic_stream_sequencer_test.cc b/gquiche/quic/core/quic_stream_sequencer_test.cc index 3ef10a75..280a876c 100644 --- a/gquiche/quic/core/quic_stream_sequencer_test.cc +++ b/gquiche/quic/core/quic_stream_sequencer_test.cc @@ -43,7 +43,7 @@ class MockStream : public QuicStreamSequencer::StreamInterface { QuicIetfTransportErrorCodes ietf_error, const std::string& details), (override)); - MOCK_METHOD(void, Reset, (QuicRstStreamErrorCode error), (override)); + MOCK_METHOD(void, ResetWithError, (QuicResetStreamError error), (override)); MOCK_METHOD(void, AddBytesConsumed, (QuicByteCount bytes), (override)); QuicStreamId id() const override { return 1; } @@ -252,8 +252,7 @@ TEST_F(QuicStreamSequencerTest, BlockedThenFullFrameAndFinConsumed) { } TEST_F(QuicStreamSequencerTest, EmptyFrame) { - if (!GetQuicReloadableFlag(quic_accept_empty_stream_frame_with_no_fin) || - !stream_.version().HasIetfQuicFrames()) { + if (!stream_.version().HasIetfQuicFrames()) { EXPECT_CALL(stream_, OnUnrecoverableError(QUIC_EMPTY_STREAM_FRAME_NO_FIN, _)); } @@ -567,7 +566,8 @@ TEST_F(QuicStreamSequencerTest, MarkConsumedError) { // Now, attempt to mark consumed more data than was readable and expect the // stream to be closed. - EXPECT_CALL(stream_, Reset(QUIC_ERROR_PROCESSING_STREAM)); + EXPECT_CALL(stream_, ResetWithError(QuicResetStreamError::FromInternal( + QUIC_ERROR_PROCESSING_STREAM))); EXPECT_QUIC_BUG(sequencer_->MarkConsumed(4), "Invalid argument to MarkConsumed." " expect to consume: 4, but not enough bytes available."); diff --git a/gquiche/quic/core/quic_stream_test.cc b/gquiche/quic/core/quic_stream_test.cc index 031f648b..c2f50392 100644 --- a/gquiche/quic/core/quic_stream_test.cc +++ b/gquiche/quic/core/quic_stream_test.cc @@ -26,7 +26,6 @@ #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_mem_slice_storage.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/quic/platform/api/quic_test_mem_slice_vector.h" #include "gquiche/quic/test_tools/quic_config_peer.h" #include "gquiche/quic/test_tools/quic_connection_peer.h" #include "gquiche/quic/test_tools/quic_flow_controller_peer.h" @@ -59,16 +58,15 @@ class TestStream : public QuicStream { sequencer()->set_level_triggered(true); } - TestStream(PendingStream* pending, - QuicSession* session, - StreamType type, - bool is_static) - : QuicStream(pending, session, type, is_static) {} + TestStream(PendingStream* pending, QuicSession* session, bool is_static) + : QuicStream(pending, session, is_static) {} MOCK_METHOD(void, OnDataAvailable, (), (override)); MOCK_METHOD(void, OnCanWriteNewData, (), (override)); + MOCK_METHOD(void, OnWriteSideInDataRecvdState, (), (override)); + using QuicStream::CanWriteNewData; using QuicStream::CanWriteNewDataAfterData; using QuicStream::CloseWriteSide; @@ -88,11 +86,11 @@ class QuicStreamTest : public QuicTestWithParam { : zero_(QuicTime::Delta::Zero()), supported_versions_(AllSupportedVersions()) {} - void Initialize() { + void Initialize(Perspective perspective = Perspective::IS_SERVER) { ParsedQuicVersionVector version_vector; version_vector.push_back(GetParam()); connection_ = new StrictMock( - &helper_, &alarm_factory_, Perspective::IS_SERVER, version_vector); + &helper_, &alarm_factory_, perspective, version_vector); connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); session_ = std::make_unique>(connection_); session_->Initialize(); @@ -133,12 +131,9 @@ class QuicStreamTest : public QuicTestWithParam { } QuicConsumedData CloseStreamOnWriteError( - QuicStreamId id, - QuicByteCount /*write_length*/, - QuicStreamOffset /*offset*/, - StreamSendingState /*state*/, - TransmissionType /*type*/, - absl::optional /*level*/) { + QuicStreamId id, QuicByteCount /*write_length*/, + QuicStreamOffset /*offset*/, StreamSendingState /*state*/, + TransmissionType /*type*/, absl::optional /*level*/) { session_->ResetStream(id, QUIC_STREAM_CANCELLED); return QuicConsumedData(1, false); } @@ -167,33 +162,59 @@ class QuicStreamTest : public QuicTestWithParam { QuicStreamId kTestStreamId = GetNthClientInitiatedBidirectionalStreamId(GetParam().transport_version, 1); + const QuicStreamId kTestPendingStreamId = + GetNthClientInitiatedUnidirectionalStreamId(GetParam().transport_version, + 1); }; -INSTANTIATE_TEST_SUITE_P(QuicStreamTests, - QuicStreamTest, +INSTANTIATE_TEST_SUITE_P(QuicStreamTests, QuicStreamTest, ::testing::ValuesIn(AllSupportedVersions()), ::testing::PrintToStringParamName()); -TEST_P(QuicStreamTest, PendingStreamStaticness) { +using PendingStreamTest = QuicStreamTest; + +INSTANTIATE_TEST_SUITE_P(PendingStreamTests, PendingStreamTest, + ::testing::ValuesIn(CurrentSupportedHttp3Versions()), + ::testing::PrintToStringParamName()); + +TEST_P(PendingStreamTest, PendingStreamStaticness) { Initialize(); - PendingStream pending(kTestStreamId + 2, session_.get()); - TestStream stream(&pending, session_.get(), StreamType::BIDIRECTIONAL, false); + PendingStream pending(kTestPendingStreamId, session_.get()); + TestStream stream(&pending, session_.get(), false); EXPECT_FALSE(stream.is_static()); - PendingStream pending2(kTestStreamId + 3, session_.get()); - TestStream stream2(&pending2, session_.get(), StreamType::BIDIRECTIONAL, - true); + PendingStream pending2(kTestPendingStreamId + 4, session_.get()); + TestStream stream2(&pending2, session_.get(), true); EXPECT_TRUE(stream2.is_static()); } -TEST_P(QuicStreamTest, PendingStreamTooMuchData) { +TEST_P(PendingStreamTest, PendingStreamType) { Initialize(); - PendingStream pending(kTestStreamId + 2, session_.get()); + PendingStream pending(kTestPendingStreamId, session_.get()); + TestStream stream(&pending, session_.get(), false); + EXPECT_EQ(stream.type(), READ_UNIDIRECTIONAL); +} + +TEST_P(PendingStreamTest, PendingStreamTypeOnClient) { + Initialize(Perspective::IS_CLIENT); + + QuicStreamId server_initiated_pending_stream_id = + GetNthServerInitiatedUnidirectionalStreamId(session_->transport_version(), + 1); + PendingStream pending(server_initiated_pending_stream_id, session_.get()); + TestStream stream(&pending, session_.get(), false); + EXPECT_EQ(stream.type(), READ_UNIDIRECTIONAL); +} + +TEST_P(PendingStreamTest, PendingStreamTooMuchData) { + Initialize(); + + PendingStream pending(kTestPendingStreamId, session_.get()); // Receive a stream frame that violates flow control: the byte offset is // higher than the receive window offset. - QuicStreamFrame frame(kTestStreamId + 2, false, + QuicStreamFrame frame(kTestPendingStreamId, false, kInitialSessionFlowControlWindowForTest + 1, "."); // Stream should not accept the frame, and the connection should be closed. @@ -202,29 +223,43 @@ TEST_P(QuicStreamTest, PendingStreamTooMuchData) { pending.OnStreamFrame(frame); } -TEST_P(QuicStreamTest, PendingStreamTooMuchDataInRstStream) { +TEST_P(PendingStreamTest, PendingStreamTooMuchDataInRstStream) { Initialize(); - PendingStream pending(kTestStreamId + 2, session_.get()); + PendingStream pending1(kTestPendingStreamId, session_.get()); // Receive a rst stream frame that violates flow control: the byte offset is // higher than the receive window offset. - QuicRstStreamFrame frame(kInvalidControlFrameId, kTestStreamId + 2, - QUIC_STREAM_CANCELLED, - kInitialSessionFlowControlWindowForTest + 1); + QuicRstStreamFrame frame1(kInvalidControlFrameId, kTestPendingStreamId, + QUIC_STREAM_CANCELLED, + kInitialSessionFlowControlWindowForTest + 1); // Pending stream should not accept the frame, and the connection should be // closed. EXPECT_CALL(*connection_, CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); - pending.OnRstStreamFrame(frame); + pending1.OnRstStreamFrame(frame1); + + QuicStreamId bidirection_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + session_->transport_version(), Perspective::IS_CLIENT); + PendingStream pending2(bidirection_stream_id, session_.get()); + // Receive a rst stream frame that violates flow control: the byte offset is + // higher than the receive window offset. + QuicRstStreamFrame frame2(kInvalidControlFrameId, bidirection_stream_id, + QUIC_STREAM_CANCELLED, + kInitialSessionFlowControlWindowForTest + 1); + // Bidirectional Pending stream should not accept the frame, and the + // connection should be closed. + EXPECT_CALL(*connection_, + CloseConnection(QUIC_FLOW_CONTROL_RECEIVED_TOO_MUCH_DATA, _, _)); + pending2.OnRstStreamFrame(frame2); } -TEST_P(QuicStreamTest, PendingStreamRstStream) { +TEST_P(PendingStreamTest, PendingStreamRstStream) { Initialize(); - PendingStream pending(kTestStreamId + 2, session_.get()); + PendingStream pending(kTestPendingStreamId, session_.get()); QuicStreamOffset final_byte_offset = 7; - QuicRstStreamFrame frame(kInvalidControlFrameId, kTestStreamId + 2, + QuicRstStreamFrame frame(kInvalidControlFrameId, kTestPendingStreamId, QUIC_STREAM_CANCELLED, final_byte_offset); // Pending stream should accept the frame and not close the connection. @@ -232,19 +267,47 @@ TEST_P(QuicStreamTest, PendingStreamRstStream) { pending.OnRstStreamFrame(frame); } -TEST_P(QuicStreamTest, FromPendingStream) { +TEST_P(PendingStreamTest, PendingStreamWindowUpdate) { + Initialize(); + + QuicStreamId bidirection_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + session_->transport_version(), Perspective::IS_CLIENT); + PendingStream pending(bidirection_stream_id, session_.get()); + QuicWindowUpdateFrame frame(kInvalidControlFrameId, bidirection_stream_id, + kDefaultFlowControlSendWindow * 2); + pending.OnWindowUpdateFrame(frame); + TestStream stream(&pending, session_.get(), false); + + EXPECT_EQ(QuicStreamPeer::SendWindowSize(&stream), + kDefaultFlowControlSendWindow * 2); +} + +TEST_P(PendingStreamTest, PendingStreamStopSending) { + Initialize(); + + QuicStreamId bidirection_stream_id = QuicUtils::GetFirstBidirectionalStreamId( + session_->transport_version(), Perspective::IS_CLIENT); + PendingStream pending(bidirection_stream_id, session_.get()); + QuicResetStreamError error = + QuicResetStreamError::FromInternal(QUIC_STREAM_INTERNAL_ERROR); + pending.OnStopSending(error); + EXPECT_TRUE(pending.GetStopSendingErrorCode()); + auto actual_error = *pending.GetStopSendingErrorCode(); + EXPECT_EQ(actual_error, error); +} + +TEST_P(PendingStreamTest, FromPendingStream) { Initialize(); - PendingStream pending(kTestStreamId + 2, session_.get()); + PendingStream pending(kTestPendingStreamId, session_.get()); - QuicStreamFrame frame(kTestStreamId + 2, false, 2, "."); + QuicStreamFrame frame(kTestPendingStreamId, false, 2, "."); pending.OnStreamFrame(frame); pending.OnStreamFrame(frame); - QuicStreamFrame frame2(kTestStreamId + 2, true, 3, "."); + QuicStreamFrame frame2(kTestPendingStreamId, true, 3, "."); pending.OnStreamFrame(frame2); - TestStream stream(&pending, session_.get(), StreamType::READ_UNIDIRECTIONAL, - false); + TestStream stream(&pending, session_.get(), false); EXPECT_EQ(3, stream.num_frames_received()); EXPECT_EQ(3u, stream.stream_bytes_read()); EXPECT_EQ(1, stream.num_duplicate_frames_received()); @@ -254,19 +317,18 @@ TEST_P(QuicStreamTest, FromPendingStream) { session_->flow_controller()->highest_received_byte_offset()); } -TEST_P(QuicStreamTest, FromPendingStreamThenData) { +TEST_P(PendingStreamTest, FromPendingStreamThenData) { Initialize(); - PendingStream pending(kTestStreamId + 2, session_.get()); + PendingStream pending(kTestPendingStreamId, session_.get()); - QuicStreamFrame frame(kTestStreamId + 2, false, 2, "."); + QuicStreamFrame frame(kTestPendingStreamId, false, 2, "."); pending.OnStreamFrame(frame); - auto stream = new TestStream(&pending, session_.get(), - StreamType::READ_UNIDIRECTIONAL, false); + auto stream = new TestStream(&pending, session_.get(), false); session_->ActivateStream(absl::WrapUnique(stream)); - QuicStreamFrame frame2(kTestStreamId + 2, true, 3, "."); + QuicStreamFrame frame2(kTestPendingStreamId, true, 3, "."); stream->OnStreamFrame(frame2); EXPECT_EQ(2, stream->num_frames_received()); @@ -865,6 +927,7 @@ TEST_P(QuicStreamTest, StreamWaitsForAcks) { EXPECT_EQ(0u, QuicStreamPeer::SendBuffer(stream_).size()); // FIN is acked. + EXPECT_CALL(*stream_, OnWriteSideInDataRecvdState()); EXPECT_TRUE(stream_->OnStreamFrameAcked(18, 0, true, QuicTime::Delta::Zero(), QuicTime::Zero(), &newly_acked_length)); @@ -908,6 +971,7 @@ TEST_P(QuicStreamTest, StreamDataGetAckedOutOfOrder) { // FIN is not acked yet. EXPECT_TRUE(stream_->IsWaitingForAcks()); EXPECT_TRUE(session_->HasUnackedStreamData()); + EXPECT_CALL(*stream_, OnWriteSideInDataRecvdState()); EXPECT_TRUE(stream_->OnStreamFrameAcked(27, 0, true, QuicTime::Delta::Zero(), QuicTime::Zero(), &newly_acked_length)); @@ -977,8 +1041,11 @@ TEST_P(QuicStreamTest, RstFrameReceivedStreamNotFinishSending) { QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), QUIC_STREAM_CANCELLED, 9); - EXPECT_CALL(*session_, MaybeSendRstStreamFrame(stream_->id(), - QUIC_RST_ACKNOWLEDGEMENT, 9)); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), 9)); stream_->OnStreamReset(rst_frame); EXPECT_EQ(1u, QuicStreamPeer::SendBuffer(stream_).size()); // Stream stops waiting for acks as it does not finish sending and rst is @@ -1020,8 +1087,11 @@ TEST_P(QuicStreamTest, ConnectionClosed) { stream_->WriteOrBufferData(kData1, false, nullptr); EXPECT_TRUE(stream_->IsWaitingForAcks()); EXPECT_TRUE(session_->HasUnackedStreamData()); - EXPECT_CALL(*session_, MaybeSendRstStreamFrame(stream_->id(), - QUIC_RST_ACKNOWLEDGEMENT, 9)); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + stream_->id(), + QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), 9)); QuicConnectionPeer::SetConnectionClose(connection_); stream_->OnConnectionClosed(QUIC_INTERNAL_ERROR, ConnectionCloseSource::FROM_SELF); @@ -1178,14 +1248,17 @@ TEST_P(QuicStreamTest, WriteMemSlices) { SetQuicFlag(FLAGS_quic_buffered_data_threshold, 100); Initialize(); - char data[1024]; - std::vector> buffers; - buffers.push_back(std::make_pair(data, ABSL_ARRAYSIZE(data))); - buffers.push_back(std::make_pair(data, ABSL_ARRAYSIZE(data))); - QuicTestMemSliceVector vector1(buffers); - QuicTestMemSliceVector vector2(buffers); - QuicMemSliceSpan span1 = vector1.span(); - QuicMemSliceSpan span2 = vector2.span(); + constexpr QuicByteCount kDataSize = 1024; + QuicBufferAllocator* allocator = + connection_->helper()->GetStreamSendBufferAllocator(); + std::vector vector1; + vector1.push_back(QuicMemSlice(QuicBuffer(allocator, kDataSize))); + vector1.push_back(QuicMemSlice(QuicBuffer(allocator, kDataSize))); + std::vector vector2; + vector2.push_back(QuicMemSlice(QuicBuffer(allocator, kDataSize))); + vector2.push_back(QuicMemSlice(QuicBuffer(allocator, kDataSize))); + absl::Span span1(vector1); + absl::Span span2(vector2); EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) .WillOnce(InvokeWithoutArgs([this]() { @@ -1196,7 +1269,7 @@ TEST_P(QuicStreamTest, WriteMemSlices) { QuicConsumedData consumed = stream_->WriteMemSlices(span1, false); EXPECT_EQ(2048u, consumed.bytes_consumed); EXPECT_FALSE(consumed.fin_consumed); - EXPECT_EQ(2 * ABSL_ARRAYSIZE(data) - 100, stream_->BufferedDataBytes()); + EXPECT_EQ(2 * kDataSize - 100, stream_->BufferedDataBytes()); EXPECT_FALSE(stream_->fin_buffered()); EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)).Times(0); @@ -1204,12 +1277,11 @@ TEST_P(QuicStreamTest, WriteMemSlices) { consumed = stream_->WriteMemSlices(span2, true); EXPECT_EQ(0u, consumed.bytes_consumed); EXPECT_FALSE(consumed.fin_consumed); - EXPECT_EQ(2 * ABSL_ARRAYSIZE(data) - 100, stream_->BufferedDataBytes()); + EXPECT_EQ(2 * kDataSize - 100, stream_->BufferedDataBytes()); EXPECT_FALSE(stream_->fin_buffered()); QuicByteCount data_to_write = - 2 * ABSL_ARRAYSIZE(data) - 100 - - GetQuicFlag(FLAGS_quic_buffered_data_threshold) + 1; + 2 * kDataSize - 100 - GetQuicFlag(FLAGS_quic_buffered_data_threshold) + 1; EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) .WillOnce(InvokeWithoutArgs([this, data_to_write]() { return session_->ConsumeData(stream_->id(), data_to_write, 100u, NO_FIN, @@ -1225,8 +1297,7 @@ TEST_P(QuicStreamTest, WriteMemSlices) { consumed = stream_->WriteMemSlices(span2, true); EXPECT_EQ(2048u, consumed.bytes_consumed); EXPECT_TRUE(consumed.fin_consumed); - EXPECT_EQ(2 * ABSL_ARRAYSIZE(data) + - GetQuicFlag(FLAGS_quic_buffered_data_threshold) - 1, + EXPECT_EQ(2 * kDataSize + GetQuicFlag(FLAGS_quic_buffered_data_threshold) - 1, stream_->BufferedDataBytes()); EXPECT_TRUE(stream_->fin_buffered()); @@ -1242,26 +1313,20 @@ TEST_P(QuicStreamTest, WriteMemSlices) { TEST_P(QuicStreamTest, WriteMemSlicesReachStreamLimit) { Initialize(); QuicStreamPeer::SetStreamBytesWritten(kMaxStreamLength - 5u, stream_); - char data[5]; std::vector> buffers; - buffers.push_back(std::make_pair(data, ABSL_ARRAYSIZE(data))); - QuicTestMemSliceVector vector1(buffers); - QuicMemSliceSpan span1 = vector1.span(); + QuicMemSlice slice1 = MemSliceFromString("12345"); EXPECT_CALL(*session_, WritevData(_, _, _, _, _, _)) .WillOnce(InvokeWithoutArgs([this]() { return session_->ConsumeData(stream_->id(), 5u, 0u, NO_FIN, NOT_RETRANSMISSION, absl::nullopt); })); // There is no buffered data before, all data should be consumed. - QuicConsumedData consumed = stream_->WriteMemSlices(span1, false); + QuicConsumedData consumed = stream_->WriteMemSlice(std::move(slice1), false); EXPECT_EQ(5u, consumed.bytes_consumed); - std::vector> buffers2; - buffers2.push_back(std::make_pair(data, 1u)); - QuicTestMemSliceVector vector2(buffers); - QuicMemSliceSpan span2 = vector2.span(); + QuicMemSlice slice2 = MemSliceFromString("6"); EXPECT_CALL(*connection_, CloseConnection(QUIC_STREAM_LENGTH_OVERFLOW, _, _)); - EXPECT_QUIC_BUG(stream_->WriteMemSlices(span2, false), + EXPECT_QUIC_BUG(stream_->WriteMemSlice(std::move(slice2), false), "Write too many data via stream"); } @@ -1312,6 +1377,7 @@ TEST_P(QuicStreamTest, StreamDataGetAckedMultipleTimes) { EXPECT_TRUE(session_->HasUnackedStreamData()); // Ack Fin. + EXPECT_CALL(*stream_, OnWriteSideInDataRecvdState()).Times(1); EXPECT_TRUE(stream_->OnStreamFrameAcked(27, 0, true, QuicTime::Delta::Zero(), QuicTime::Zero(), &newly_acked_length)); @@ -1548,10 +1614,14 @@ TEST_P(QuicStreamTest, ResetStreamOnTtlExpiresRetransmitLostData) { // Verify stream gets reset because TTL expires. if (session_->version().UsesHttp3()) { EXPECT_CALL(*session_, - MaybeSendStopSendingFrame(_, QUIC_STREAM_TTL_EXPIRED)) + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_TTL_EXPIRED))) .Times(1); } - EXPECT_CALL(*session_, MaybeSendRstStreamFrame(_, QUIC_STREAM_TTL_EXPIRED, _)) + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_TTL_EXPIRED), _)) .Times(1); stream_->OnCanWrite(); } @@ -1572,10 +1642,14 @@ TEST_P(QuicStreamTest, ResetStreamOnTtlExpiresEarlyRetransmitData) { // Verify stream gets reset because TTL expires. if (session_->version().UsesHttp3()) { EXPECT_CALL(*session_, - MaybeSendStopSendingFrame(_, QUIC_STREAM_TTL_EXPIRED)) + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_TTL_EXPIRED))) .Times(1); } - EXPECT_CALL(*session_, MaybeSendRstStreamFrame(_, QUIC_STREAM_TTL_EXPIRED, _)) + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_TTL_EXPIRED), _)) .Times(1); stream_->RetransmitStreamData(0, 100, false, PTO_RETRANSMISSION); } @@ -1634,8 +1708,7 @@ TEST_P(QuicStreamTest, RstStreamFrameChangesCloseOffset) { TEST_P(QuicStreamTest, EmptyStreamFrameWithNoFin) { Initialize(); QuicStreamFrame empty_stream_frame(stream_->id(), false, 0, ""); - if (GetQuicReloadableFlag(quic_accept_empty_stream_frame_with_no_fin) && - stream_->version().HasIetfQuicFrames()) { + if (stream_->version().HasIetfQuicFrames()) { EXPECT_CALL(*connection_, CloseConnection(QUIC_EMPTY_STREAM_FRAME_NO_FIN, _, _)) .Times(0); @@ -1647,6 +1720,15 @@ TEST_P(QuicStreamTest, EmptyStreamFrameWithNoFin) { stream_->OnStreamFrame(empty_stream_frame); } +TEST_P(QuicStreamTest, SendRstWithCustomIetfCode) { + Initialize(); + QuicResetStreamError error(QUIC_STREAM_CANCELLED, 0x1234abcd); + EXPECT_CALL(*session_, MaybeSendRstStreamFrame(kTestStreamId, error, _)) + .Times(1); + stream_->ResetWithError(error); + EXPECT_TRUE(rst_sent()); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/quic_tag.cc b/gquiche/quic/core/quic_tag.cc index 5a22e4d9..32461fdc 100644 --- a/gquiche/quic/core/quic_tag.cc +++ b/gquiche/quic/core/quic_tag.cc @@ -12,7 +12,7 @@ #include "absl/strings/str_split.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { diff --git a/gquiche/quic/core/quic_time.cc b/gquiche/quic/core/quic_time.cc index 3b6c6890..fad77b07 100644 --- a/gquiche/quic/core/quic_time.cc +++ b/gquiche/quic/core/quic_time.cc @@ -21,11 +21,11 @@ std::string QuicTime::Delta::ToDebuggingValue() const { // For debugging purposes, always display the value with the highest precision // available. - if (absolute_value > kSecondInMicroseconds && + if (absolute_value >= kSecondInMicroseconds && absolute_value % kSecondInMicroseconds == 0) { return absl::StrCat(time_offset_ / kSecondInMicroseconds, "s"); } - if (absolute_value > kMillisecondInMicroseconds && + if (absolute_value >= kMillisecondInMicroseconds && absolute_value % kMillisecondInMicroseconds == 0) { return absl::StrCat(time_offset_ / kMillisecondInMicroseconds, "ms"); } diff --git a/gquiche/quic/core/quic_time_test.cc b/gquiche/quic/core/quic_time_test.cc index a75cb1f7..d2f1e487 100644 --- a/gquiche/quic/core/quic_time_test.cc +++ b/gquiche/quic/core/quic_time_test.cc @@ -87,8 +87,11 @@ TEST_F(QuicTimeDeltaTest, DebuggingValue) { const QuicTime::Delta one_ms = QuicTime::Delta::FromMilliseconds(1); const QuicTime::Delta one_s = QuicTime::Delta::FromSeconds(1); + EXPECT_EQ("1s", one_s.ToDebuggingValue()); EXPECT_EQ("3s", (3 * one_s).ToDebuggingValue()); + EXPECT_EQ("1ms", one_ms.ToDebuggingValue()); EXPECT_EQ("3ms", (3 * one_ms).ToDebuggingValue()); + EXPECT_EQ("1us", one_us.ToDebuggingValue()); EXPECT_EQ("3us", (3 * one_us).ToDebuggingValue()); EXPECT_EQ("3001us", (3 * one_ms + one_us).ToDebuggingValue()); diff --git a/gquiche/quic/core/quic_time_wait_list_manager.cc b/gquiche/quic/core/quic_time_wait_list_manager.cc index fb397567..08ad940d 100644 --- a/gquiche/quic/core/quic_time_wait_list_manager.cc +++ b/gquiche/quic/core/quic_time_wait_list_manager.cc @@ -22,16 +22,15 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_socket_address.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { // A very simple alarm that just informs the QuicTimeWaitListManager to clean // up old connection_ids. This alarm should be cancelled and deleted before // the QuicTimeWaitListManager is deleted. -class ConnectionIdCleanUpAlarm : public QuicAlarm::Delegate { +class ConnectionIdCleanUpAlarm : public QuicAlarm::DelegateWithoutContext { public: explicit ConnectionIdCleanUpAlarm( QuicTimeWaitListManager* time_wait_list_manager) @@ -92,9 +91,6 @@ QuicTimeWaitListManager::~QuicTimeWaitListManager() { QuicTimeWaitListManager::ConnectionIdMap::iterator QuicTimeWaitListManager::FindConnectionIdDataInMap( const QuicConnectionId& connection_id) { - if (!use_indirect_connection_id_map_) { - return connection_id_map_.find(connection_id); - } auto it = indirect_connection_id_map_.find(connection_id); if (it == indirect_connection_id_map_.end()) { return connection_id_map_.end(); @@ -107,12 +103,8 @@ void QuicTimeWaitListManager::AddConnectionIdDataToMap( int num_packets, TimeWaitAction action, TimeWaitConnectionInfo info) { - if (use_indirect_connection_id_map_) { - QUIC_RESTART_FLAG_COUNT_N(quic_time_wait_list_support_multiple_cid_v2, 1, - 3); - for (const auto& cid : info.active_connection_ids) { - indirect_connection_id_map_[cid] = canonical_connection_id; - } + for (const auto& cid : info.active_connection_ids) { + indirect_connection_id_map_[cid] = canonical_connection_id; } ConnectionIdData data(num_packets, clock_->ApproximateNow(), action, std::move(info)); @@ -122,24 +114,18 @@ void QuicTimeWaitListManager::AddConnectionIdDataToMap( void QuicTimeWaitListManager::RemoveConnectionDataFromMap( ConnectionIdMap::iterator it) { - if (use_indirect_connection_id_map_) { - QUIC_RESTART_FLAG_COUNT_N(quic_time_wait_list_support_multiple_cid_v2, 2, - 3); - for (const auto& cid : it->second.info.active_connection_ids) { - indirect_connection_id_map_.erase(cid); - } + for (const auto& cid : it->second.info.active_connection_ids) { + indirect_connection_id_map_.erase(cid); } connection_id_map_.erase(it); } void QuicTimeWaitListManager::AddConnectionIdToTimeWait( - QuicConnectionId connection_id, TimeWaitAction action, TimeWaitConnectionInfo info) { QUICHE_DCHECK(!info.active_connection_ids.empty()); const QuicConnectionId& canonical_connection_id = - use_indirect_connection_id_map_ ? info.active_connection_ids.front() - : connection_id; + info.active_connection_ids.front(); QUICHE_DCHECK(action != SEND_TERMINATION_PACKETS || !info.termination_packets.empty()); QUICHE_DCHECK(action != DO_NOTHING || info.ietf_quic); @@ -155,26 +141,18 @@ void QuicTimeWaitListManager::AddConnectionIdToTimeWait( GetQuicFlag(FLAGS_quic_time_wait_list_max_connections); QUICHE_DCHECK(connection_id_map_.empty() || num_connections() < static_cast(max_connections)); - if (use_indirect_connection_id_map_ && new_connection_id) { - QUIC_RESTART_FLAG_COUNT_N(quic_time_wait_list_support_multiple_cid_v2, 3, - 3); + if (new_connection_id) { for (const auto& cid : info.active_connection_ids) { visitor_->OnConnectionAddedToTimeWaitList(cid); } } AddConnectionIdDataToMap(canonical_connection_id, num_packets, action, std::move(info)); - if (!use_indirect_connection_id_map_ && new_connection_id) { - visitor_->OnConnectionAddedToTimeWaitList(canonical_connection_id); - } } bool QuicTimeWaitListManager::IsConnectionIdInTimeWait( QuicConnectionId connection_id) const { - if (use_indirect_connection_id_map_) { - return indirect_connection_id_map_.contains(connection_id); - } - return QuicContainsKey(connection_id_map_, connection_id); + return indirect_connection_id_map_.contains(connection_id); } void QuicTimeWaitListManager::OnBlockedWriterCanWrite() { @@ -325,6 +303,12 @@ void QuicTimeWaitListManager::SendPublicReset( if (ietf_quic) { std::unique_ptr ietf_reset_packet = BuildIetfStatelessResetPacket(connection_id, received_packet_length); + if (ietf_reset_packet == nullptr) { + // This could happen when trying to reject a short header packet of + // a connection which is in the time wait list (and with no termination + // packet). + return; + } QUIC_DVLOG(2) << "Dispatcher sending IETF reset packet for " << connection_id << std::endl << quiche::QuicheTextUtils::HexDump( @@ -386,6 +370,12 @@ bool QuicTimeWaitListManager::SendOrQueuePacket( QUIC_LOG(ERROR) << "Tried to send or queue a null packet"; return true; } + if (pending_packets_queue_.size() >= + GetQuicFlag(FLAGS_quic_time_wait_list_max_pending_packets)) { + // There are too many pending packets. + QUIC_CODE_COUNT(quic_too_many_pending_packets_in_time_wait); + return true; + } if (WriteToWire(packet.get())) { // Allow the packet to be deleted upon leaving this function. return true; diff --git a/gquiche/quic/core/quic_time_wait_list_manager.h b/gquiche/quic/core/quic_time_wait_list_manager.h index 863a4f2c..ea42c818 100644 --- a/gquiche/quic/core/quic_time_wait_list_manager.h +++ b/gquiche/quic/core/quic_time_wait_list_manager.h @@ -19,8 +19,8 @@ #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_types.h" -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_flags.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -99,14 +99,13 @@ class QUIC_NO_EXPORT QuicTimeWaitListManager QuicTimeWaitListManager& operator=(const QuicTimeWaitListManager&) = delete; ~QuicTimeWaitListManager() override; - // Adds the given connection_id to time wait state for time_wait_period_. - // If |termination_packets| are provided, copies of these packets will be sent - // when a packet with this connection ID is processed. Any termination packets - // will be move from |termination_packets| and will become owned by the - // manager. |action| specifies what the time wait list manager should do when - // processing packets of the connection. - virtual void AddConnectionIdToTimeWait(QuicConnectionId connection_id, - TimeWaitAction action, + // Adds the connection IDs in info to time wait state for time_wait_period_. + // If |info|.termination_packets are provided, copies of these packets will be + // sent when a packet with one of these connection IDs is processed. Any + // termination packets will be move from |info|.termination_packets and will + // become owned by the manager. |action| specifies what the time wait list + // manager should do when processing packets of the connection. + virtual void AddConnectionIdToTimeWait(TimeWaitAction action, TimeWaitConnectionInfo info); // Returns true if the connection_id is in time wait state, false otherwise. @@ -224,7 +223,7 @@ class QUIC_NO_EXPORT QuicTimeWaitListManager virtual bool SendOrQueuePacket(std::unique_ptr packet, const QuicPerPacketContext* packet_context); - const QuicCircularDeque>& + const quiche::QuicheCircularDeque>& pending_packets_queue() const { return pending_packets_queue_; } @@ -285,10 +284,11 @@ class QUIC_NO_EXPORT QuicTimeWaitListManager TimeWaitConnectionInfo info; }; - // QuicLinkedHashMap allows lookup by ConnectionId and traversal in add order. - using ConnectionIdMap = QuicLinkedHashMap; + // QuicheLinkedHashMap allows lookup by ConnectionId + // and traversal in add order. + using ConnectionIdMap = quiche::QuicheLinkedHashMap; // Do not use find/emplace/erase on this map directly. Use // FindConnectionIdDataInMap, AddConnectionIdDateToMap, // RemoveConnectionDataFromMap instead. @@ -318,7 +318,8 @@ class QUIC_NO_EXPORT QuicTimeWaitListManager // Pending termination packets that need to be sent out to the peer when we // are given a chance to write by the dispatcher. - QuicCircularDeque> pending_packets_queue_; + quiche::QuicheCircularDeque> + pending_packets_queue_; // Time period for which connection_ids should remain in time wait state. const QuicTime::Delta time_wait_period_; @@ -335,11 +336,6 @@ class QUIC_NO_EXPORT QuicTimeWaitListManager // Interface that manages blocked writers. Visitor* visitor_; - - // When this is default true, remove the connection_id argument of - // AddConnectionIdToTimeWait. - bool use_indirect_connection_id_map_ = - GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid_v2); }; } // namespace quic diff --git a/gquiche/quic/core/quic_time_wait_list_manager_test.cc b/gquiche/quic/core/quic_time_wait_list_manager_test.cc index d1619e36..78e59a70 100644 --- a/gquiche/quic/core/quic_time_wait_list_manager_test.cc +++ b/gquiche/quic/core/quic_time_wait_list_manager_test.cc @@ -157,7 +157,7 @@ class QuicTimeWaitListManagerTest : public QuicTest { termination_packets.push_back(std::unique_ptr( new QuicEncryptedPacket(nullptr, 0, false))); time_wait_list_manager_.AddConnectionIdToTimeWait( - connection_id, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, TimeWaitConnectionInfo(false, &termination_packets, {connection_id})); } @@ -167,9 +167,8 @@ class QuicTimeWaitListManagerTest : public QuicTest { QuicTimeWaitListManager::TimeWaitAction action, std::vector>* packets) { time_wait_list_manager_.AddConnectionIdToTimeWait( - connection_id, action, - TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), packets, - {connection_id})); + action, TimeWaitConnectionInfo(version.HasIetfInvariantHeader(), + packets, {connection_id})); } bool IsConnectionIdInTimeWait(QuicConnectionId connection_id) { @@ -204,13 +203,13 @@ class QuicTimeWaitListManagerTest : public QuicTest { bool ValidPublicResetPacketPredicate( QuicConnectionId expected_connection_id, - const testing::tuple& packet_buffer) { + const std::tuple& packet_buffer) { FramerVisitorCapturingPublicReset visitor(expected_connection_id); QuicFramer framer(AllSupportedVersions(), QuicTime::Zero(), Perspective::IS_CLIENT, kQuicDefaultConnectionIdLength); framer.set_visitor(&visitor); - QuicEncryptedPacket encrypted(testing::get<0>(packet_buffer), - testing::get<1>(packet_buffer)); + QuicEncryptedPacket encrypted(std::get<0>(packet_buffer), + std::get<1>(packet_buffer)); framer.ProcessPacket(encrypted); QuicPublicResetPacket packet = visitor.public_reset_packet(); bool public_reset_is_valid = @@ -230,10 +229,10 @@ bool ValidPublicResetPacketPredicate( return public_reset_is_valid || stateless_reset_is_valid; } -Matcher> PublicResetPacketEq( +Matcher> PublicResetPacketEq( QuicConnectionId connection_id) { return Truly( - [connection_id](const testing::tuple packet_buffer) { + [connection_id](const std::tuple packet_buffer) { return ValidPublicResetPacketPredicate(connection_id, packet_buffer); }); } @@ -450,10 +449,6 @@ TEST_F(QuicTimeWaitListManagerTest, CleanUpOldConnectionIds) { TEST_F(QuicTimeWaitListManagerTest, CleanUpOldConnectionIdsForMultipleConnectionIdsPerConnection) { - if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid_v2)) { - return; - } - connection_id_ = TestConnectionId(7); const size_t kConnectionCloseLength = 100; EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); @@ -467,7 +462,7 @@ TEST_F(QuicTimeWaitListManagerTest, std::vector active_connection_ids{connection_id_, TestConnectionId(8)}; time_wait_list_manager_.AddConnectionIdToTimeWait( - connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, active_connection_ids, QuicTime::Delta::Zero())); @@ -675,7 +670,7 @@ TEST_F(QuicTimeWaitListManagerTest, std::unique_ptr(new QuicEncryptedPacket( new char[kConnectionCloseLength], kConnectionCloseLength, true))); time_wait_list_manager_.AddConnectionIdToTimeWait( - connection_id_, QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, + QuicTimeWaitListManager::SEND_TERMINATION_PACKETS, TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, {connection_id_})); @@ -701,7 +696,7 @@ TEST_F(QuicTimeWaitListManagerTest, new char[kConnectionCloseLength], kConnectionCloseLength, true))); // Add a CONNECTION_CLOSE termination packet. time_wait_list_manager_.AddConnectionIdToTimeWait( - connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, {connection_id_})); EXPECT_CALL(writer_, WritePacket(_, kConnectionCloseLength, @@ -717,10 +712,6 @@ TEST_F(QuicTimeWaitListManagerTest, TEST_F(QuicTimeWaitListManagerTest, SendConnectionClosePacketsForMultipleConnectionIds) { - if (!GetQuicRestartFlag(quic_time_wait_list_support_multiple_cid_v2)) { - return; - } - connection_id_ = TestConnectionId(7); const size_t kConnectionCloseLength = 100; EXPECT_CALL(visitor_, OnConnectionAddedToTimeWaitList(connection_id_)); @@ -734,7 +725,7 @@ TEST_F(QuicTimeWaitListManagerTest, std::vector active_connection_ids{connection_id_, TestConnectionId(8)}; time_wait_list_manager_.AddConnectionIdToTimeWait( - connection_id_, QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, + QuicTimeWaitListManager::SEND_CONNECTION_CLOSE_PACKETS, TimeWaitConnectionInfo(/*ietf_quic=*/true, &termination_packets, active_connection_ids, QuicTime::Delta::Zero())); @@ -750,11 +741,48 @@ TEST_F(QuicTimeWaitListManagerTest, } } +// Regression test for b/184053898. +TEST_F(QuicTimeWaitListManagerTest, DonotCrashOnNullStatelessReset) { + // Received a packet with length < + // QuicFramer::GetMinStatelessResetPacketLength(), and this will result in a + // null stateless reset. + time_wait_list_manager_.SendPublicReset( + self_address_, peer_address_, TestConnectionId(1), + /*ietf_quic=*/true, + /*received_packet_length=*/ + QuicFramer::GetMinStatelessResetPacketLength() - 1, + /*packet_context=*/nullptr); +} + TEST_F(QuicTimeWaitListManagerTest, SendOrQueueNullPacket) { QuicTimeWaitListManagerPeer::SendOrQueuePacket(&time_wait_list_manager_, nullptr, nullptr); } +TEST_F(QuicTimeWaitListManagerTest, TooManyPendingPackets) { + SetQuicFlag(FLAGS_quic_time_wait_list_max_pending_packets, 5); + const size_t kNumOfUnProcessablePackets = 2048; + EXPECT_CALL(visitor_, OnWriteBlocked(&time_wait_list_manager_)) + .Times(testing::AnyNumber()); + // Write block for the next packets. + EXPECT_CALL(writer_, + WritePacket(_, _, self_address_.host(), peer_address_, _)) + .With(Args<0, 1>(PublicResetPacketEq(TestConnectionId(1)))) + .WillOnce(DoAll(Assign(&writer_is_blocked_, true), + Return(WriteResult(WRITE_STATUS_BLOCKED, EAGAIN)))); + for (size_t i = 0; i < kNumOfUnProcessablePackets; ++i) { + time_wait_list_manager_.SendPublicReset( + self_address_, peer_address_, TestConnectionId(1), + /*ietf_quic=*/true, + /*received_packet_length=*/ + QuicFramer::GetMinStatelessResetPacketLength() + 1, + /*packet_context=*/nullptr); + } + // Verify pending packet queue size is limited. + EXPECT_EQ(5u, QuicTimeWaitListManagerPeer::PendingPacketsQueueSize( + &time_wait_list_manager_)); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/quic_types.cc b/gquiche/quic/core/quic_types.cc index 24ec0d93..4f32bc0c 100644 --- a/gquiche/quic/core/quic_types.cc +++ b/gquiche/quic/core/quic_types.cc @@ -8,6 +8,7 @@ #include "absl/strings/str_cat.h" #include "gquiche/quic/core/quic_error_codes.h" +#include "gquiche/common/print_elements.h" namespace quic { @@ -225,7 +226,6 @@ std::string TransmissionTypeToString(TransmissionType transmission_type) { return "INVALID_TRANSMISSION_TYPE"; } return absl::StrCat("Unknown(", static_cast(transmission_type), ")"); - break; } } @@ -267,7 +267,6 @@ std::string MessageStatusToString(MessageStatus message_status) { RETURN_STRING_LITERAL(MESSAGE_STATUS_INTERNAL_ERROR); default: return absl::StrCat("Unknown(", static_cast(message_status), ")"); - break; } } @@ -292,7 +291,6 @@ std::string PacketNumberSpaceToString(PacketNumberSpace packet_number_space) { default: return absl::StrCat("Unknown(", static_cast(packet_number_space), ")"); - break; } } @@ -320,7 +318,6 @@ std::string EncryptionLevelToString(EncryptionLevel level) { RETURN_STRING_LITERAL(ENCRYPTION_FORWARD_SECURE); default: return absl::StrCat("Unknown(", static_cast(level), ")"); - break; } } @@ -329,6 +326,25 @@ std::ostream& operator<<(std::ostream& os, EncryptionLevel level) { return os; } +absl::string_view ClientCertModeToString(ClientCertMode mode) { +#define RETURN_REASON_LITERAL(x) \ + case ClientCertMode::x: \ + return #x + switch (mode) { + RETURN_REASON_LITERAL(kNone); + RETURN_REASON_LITERAL(kRequest); + RETURN_REASON_LITERAL(kRequire); + default: + return ""; + } +#undef RETURN_REASON_LITERAL +} + +std::ostream& operator<<(std::ostream& os, ClientCertMode mode) { + os << ClientCertModeToString(mode); + return os; +} + std::string QuicConnectionCloseTypeString(QuicConnectionCloseType type) { switch (type) { RETURN_STRING_LITERAL(GOOGLE_QUIC_CONNECTION_CLOSE); @@ -336,7 +352,6 @@ std::string QuicConnectionCloseTypeString(QuicConnectionCloseType type) { RETURN_STRING_LITERAL(IETF_QUIC_APPLICATION_CONNECTION_CLOSE); default: return absl::StrCat("Unknown(", static_cast(type), ")"); - break; } } @@ -378,7 +393,6 @@ std::string KeyUpdateReasonString(KeyUpdateReason reason) { RETURN_REASON_LITERAL(kLocalKeyUpdateLimitOverride); default: return absl::StrCat("Unknown(", static_cast(reason), ")"); - break; } #undef RETURN_REASON_LITERAL } @@ -388,6 +402,25 @@ std::ostream& operator<<(std::ostream& os, const KeyUpdateReason reason) { return os; } +bool operator==(const ParsedClientHello& a, const ParsedClientHello& b) { + return a.sni == b.sni && a.uaid == b.uaid && a.alpns == b.alpns && + a.legacy_version_encapsulation_inner_packet == + b.legacy_version_encapsulation_inner_packet && + a.retry_token == b.retry_token && + a.resumption_attempted == b.resumption_attempted && + a.early_data_attempted == b.early_data_attempted; +} + +std::ostream& operator<<(std::ostream& os, + const ParsedClientHello& parsed_chlo) { + os << "{ sni:" << parsed_chlo.sni << ", uaid:" << parsed_chlo.uaid + << ", alpns:" << quiche::PrintElements(parsed_chlo.alpns) + << ", len(retry_token):" << parsed_chlo.retry_token.size() + << ", len(inner_packet):" + << parsed_chlo.legacy_version_encapsulation_inner_packet.size() << " }"; + return os; +} + #undef RETURN_STRING_LITERAL // undef for jumbo builds } // namespace quic diff --git a/gquiche/quic/core/quic_types.h b/gquiche/quic/core/quic_types.h index c8660793..c37b10c5 100644 --- a/gquiche/quic/core/quic_types.h +++ b/gquiche/quic/core/quic_types.h @@ -12,6 +12,7 @@ #include #include +#include "absl/container/inlined_vector.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_packet_number.h" @@ -25,7 +26,13 @@ namespace quic { using QuicPacketLength = uint16_t; using QuicControlFrameId = uint32_t; using QuicMessageId = uint32_t; -using QuicDatagramFlowId = uint64_t; + +// TODO(b/181256914) replace QuicDatagramStreamId with QuicStreamId once we +// remove support for draft-ietf-masque-h3-datagram-00 in favor of later drafts. +using QuicDatagramStreamId = uint64_t; +using QuicDatagramContextId = uint64_t; +// Note that for draft-ietf-masque-h3-datagram-00, we represent the flow ID as a +// QuicDatagramStreamId. // IMPORTANT: IETF QUIC defines stream IDs and stream counts as being unsigned // 62-bit numbers. However, we have decided to only support up to 2^32-1 streams @@ -49,6 +56,10 @@ using StatelessResetToken = std::array; // WebTransport session IDs are stream IDs. using WebTransportSessionId = uint64_t; +// WebTransport stream reset codes are 8-bit. +using WebTransportStreamError = uint8_t; +// WebTransport session error codes are 32-bit. +using WebTransportSessionError = uint32_t; enum : size_t { kQuicPathFrameBufferSize = 8 }; using QuicPathFrameBuffer = std::array; @@ -72,8 +83,7 @@ struct QUIC_EXPORT_PRIVATE QuicConsumedData { // default gtest object printer to read uninitialize memory. So we need // to teach gtest how to print this object. QUIC_EXPORT_PRIVATE friend std::ostream& operator<<( - std::ostream& os, - const QuicConsumedData& s); + std::ostream& os, const QuicConsumedData& s); // How many bytes were consumed. size_t bytes_consumed; @@ -191,8 +201,7 @@ QUIC_EXPORT_PRIVATE std::string TransmissionTypeToString( TransmissionType transmission_type); QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - TransmissionType transmission_type); + std::ostream& os, TransmissionType transmission_type); enum HasRetransmittableData : uint8_t { NO_RETRANSMITTABLE_DATA, @@ -213,8 +222,7 @@ enum class ConnectionCloseSource { FROM_PEER, FROM_SELF }; QUIC_EXPORT_PRIVATE std::string ConnectionCloseSourceToString( ConnectionCloseSource connection_close_source); QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const ConnectionCloseSource& connection_close_source); + std::ostream& os, const ConnectionCloseSource& connection_close_source); // Should a connection be closed silently or not. enum class ConnectionCloseBehavior { @@ -226,8 +234,7 @@ enum class ConnectionCloseBehavior { QUIC_EXPORT_PRIVATE std::string ConnectionCloseBehaviorToString( ConnectionCloseBehavior connection_close_behavior); QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const ConnectionCloseBehavior& connection_close_behavior); + std::ostream& os, const ConnectionCloseBehavior& connection_close_behavior); enum QuicFrameType : uint8_t { // Regular frame types. The values set here cannot change without the @@ -331,7 +338,12 @@ enum QuicIetfFrameType : uint64_t { IETF_EXTENSION_MESSAGE_V99 = 0x31, // An QUIC extension frame for sender control of acknowledgement delays - IETF_ACK_FREQUENCY = 0xaf + IETF_ACK_FREQUENCY = 0xaf, + + // A QUIC extension frame which augments the IETF_ACK frame definition with + // packet receive timestamps. + // TODO(ianswett): Determine a proper value to replace this temporary value. + IETF_ACK_RECEIVE_TIMESTAMPS = 0x22, }; QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, const QuicIetfFrameType& c); @@ -474,12 +486,18 @@ QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, // Enumeration of whether a server endpoint will request a client certificate, // and whether that endpoint requires a valid client certificate to establish a // connection. -enum class ClientCertMode { +enum class ClientCertMode : uint8_t { kNone, // Do not request a client certificate. Default server behavior. kRequest, // Request a certificate, but allow unauthenticated connections. kRequire, // Require clients to provide a valid certificate. }; +QUIC_EXPORT_PRIVATE absl::string_view ClientCertModeToString( + ClientCertMode mode); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, + ClientCertMode mode); + enum AddressChangeType : uint8_t { // IP address and port remain unchanged. NO_CHANGE, @@ -567,8 +585,7 @@ struct QUIC_EXPORT_PRIVATE AckedPacket { receive_timestamp(receive_timestamp) {} friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const AckedPacket& acked_packet); + std::ostream& os, const AckedPacket& acked_packet); QuicPacketNumber packet_number; // Number of bytes sent in the packet that was acknowledged. @@ -580,7 +597,7 @@ struct QUIC_EXPORT_PRIVATE AckedPacket { }; // A vector of acked packets. -using AckedPacketVector = QuicInlinedVector; +using AckedPacketVector = absl::InlinedVector; // Information about a newly lost packet. struct QUIC_EXPORT_PRIVATE LostPacket { @@ -588,8 +605,7 @@ struct QUIC_EXPORT_PRIVATE LostPacket { : packet_number(packet_number), bytes_lost(bytes_lost) {} friend QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const LostPacket& lost_packet); + std::ostream& os, const LostPacket& lost_packet); QuicPacketNumber packet_number; // Number of bytes sent in the packet that was lost. @@ -597,7 +613,7 @@ struct QUIC_EXPORT_PRIVATE LostPacket { }; // A vector of lost packets. -using LostPacketVector = QuicInlinedVector; +using LostPacketVector = absl::InlinedVector; // Please note, this value cannot used directly for packet serialization. enum QuicLongHeaderType : uint8_t { @@ -674,7 +690,7 @@ enum WriteStreamDataResult { WRITE_FAILED, // Trying to write nonexistent data of a stream }; -enum StreamType { +enum StreamType : uint8_t { // Bidirectional streams allow for data to be sent in both directions. BIDIRECTIONAL, @@ -737,8 +753,7 @@ enum QuicConnectionCloseType { }; QUIC_EXPORT_PRIVATE std::ostream& operator<<( - std::ostream& os, - const QuicConnectionCloseType type); + std::ostream& os, const QuicConnectionCloseType type); QUIC_EXPORT_PRIVATE std::string QuicConnectionCloseTypeString( QuicConnectionCloseType type); @@ -826,6 +841,50 @@ QUIC_EXPORT_PRIVATE std::ostream& operator<<(std::ostream& os, QUIC_EXPORT_PRIVATE std::string KeyUpdateReasonString(KeyUpdateReason reason); +// QuicSSLConfig contains configurations to be applied on a SSL object, which +// overrides the configurations in SSL_CTX. +struct QUIC_NO_EXPORT QuicSSLConfig { + // Whether TLS early data should be enabled. If not set, default to enabled. + absl::optional early_data_enabled; + // Whether TLS session tickets are supported. If not set, default to + // supported. + absl::optional disable_ticket_support; + // If set, used to configure the SSL object with + // SSL_set_signing_algorithm_prefs. + absl::optional> signing_algorithm_prefs; + // Client certificate mode for mTLS support. Only used at server side. + ClientCertMode client_cert_mode = ClientCertMode::kNone; +}; + +// QuicDelayedSSLConfig contains a subset of SSL config that can be applied +// after BoringSSL's early select certificate callback. This overwrites all SSL +// configs applied before cert selection. +struct QUIC_NO_EXPORT QuicDelayedSSLConfig { + // Client certificate mode for mTLS support. Only used at server side. + // absl::nullopt means do not change client certificate mode. + absl::optional client_cert_mode; +}; + +// ParsedClientHello contains client hello information extracted from a fully +// received client hello. +struct QUIC_NO_EXPORT ParsedClientHello { + std::string sni; // QUIC crypto and TLS. + std::string uaid; // QUIC crypto only. + std::vector alpns; // QUIC crypto and TLS. + std::string legacy_version_encapsulation_inner_packet; // QUIC crypto only. + // The unvalidated retry token from the last received packet of a potentially + // multi-packet client hello. TLS only. + std::string retry_token; + bool resumption_attempted = false; // TLS only. + bool early_data_attempted = false; // TLS only. +}; + +QUIC_EXPORT_PRIVATE bool operator==(const ParsedClientHello& a, + const ParsedClientHello& b); + +QUIC_EXPORT_PRIVATE std::ostream& operator<<( + std::ostream& os, const ParsedClientHello& parsed_chlo); + } // namespace quic #endif // QUICHE_QUIC_CORE_QUIC_TYPES_H_ diff --git a/gquiche/quic/core/quic_unacked_packet_map.cc b/gquiche/quic/core/quic_unacked_packet_map.cc index 00be8ca6..b69d82a7 100644 --- a/gquiche/quic/core/quic_unacked_packet_map.cc +++ b/gquiche/quic/core/quic_unacked_packet_map.cc @@ -8,6 +8,7 @@ #include #include +#include "absl/container/inlined_vector.h" #include "gquiche/quic/core/quic_connection_stats.h" #include "gquiche/quic/core/quic_packet_number.h" #include "gquiche/quic/core/quic_types.h" @@ -158,7 +159,13 @@ void QuicUnackedPacketMap::AddSentPacket(SerializedPacket* mutable_packet, largest_sent_largest_acked_.UpdateMax(packet.largest_acked); if (!measure_rtt) { - QUIC_BUG_IF(quic_bug_12645_2, set_in_flight); + QUIC_BUG_IF(quic_bug_12645_2, set_in_flight) + << "Packet " << mutable_packet->packet_number << ", transmission type " + << TransmissionTypeToString(mutable_packet->transmission_type) + << ", retransmittable frames: " + << QuicFramesToString(mutable_packet->retransmittable_frames) + << ", nonretransmittable_frames: " + << QuicFramesToString(mutable_packet->nonretransmittable_frames); info.state = NOT_CONTRIBUTING_RTT; } @@ -174,7 +181,7 @@ void QuicUnackedPacketMap::AddSentPacket(SerializedPacket* mutable_packet, last_inflight_packet_sent_time_ = sent_time; last_inflight_packets_sent_time_[packet_number_space] = sent_time; } - unacked_packets_.push_back(info); + unacked_packets_.push_back(std::move(info)); // Swap the retransmittable frames to avoid allocations. // TODO(ianswett): Could use emplace_back when Chromium can. if (has_crypto_handshake) { @@ -325,9 +332,9 @@ void QuicUnackedPacketMap::RemoveFromInFlight(QuicPacketNumber packet_number) { RemoveFromInFlight(info); } -QuicInlinedVector +absl::InlinedVector QuicUnackedPacketMap::NeuterUnencryptedPackets() { - QuicInlinedVector neutered_packets; + absl::InlinedVector neutered_packets; QuicPacketNumber packet_number = GetLeastUnacked(); for (QuicUnackedPacketMap::iterator it = begin(); it != end(); ++it, ++packet_number) { @@ -353,9 +360,9 @@ QuicUnackedPacketMap::NeuterUnencryptedPackets() { return neutered_packets; } -QuicInlinedVector +absl::InlinedVector QuicUnackedPacketMap::NeuterHandshakePackets() { - QuicInlinedVector neutered_packets; + absl::InlinedVector neutered_packets; QuicPacketNumber packet_number = GetLeastUnacked(); for (QuicUnackedPacketMap::iterator it = begin(); it != end(); ++it, ++packet_number) { diff --git a/gquiche/quic/core/quic_unacked_packet_map.h b/gquiche/quic/core/quic_unacked_packet_map.h index eb7e48d3..f78893a9 100644 --- a/gquiche/quic/core/quic_unacked_packet_map.h +++ b/gquiche/quic/core/quic_unacked_packet_map.h @@ -7,15 +7,15 @@ #include #include -#include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_transmission_info.h" #include "gquiche/quic/core/session_notifier_interface.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flags.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -70,12 +70,12 @@ class QUIC_EXPORT_PRIVATE QuicUnackedPacketMap { // Called to neuter all unencrypted packets to ensure they do not get // retransmitted. Returns a vector of neutered packet numbers. - QuicInlinedVector NeuterUnencryptedPackets(); + absl::InlinedVector NeuterUnencryptedPackets(); // Called to neuter packets in handshake packet number space to ensure they do // not get retransmitted. Returns a vector of neutered packet numbers. // TODO(fayang): Consider to combine this with NeuterUnencryptedPackets. - QuicInlinedVector NeuterHandshakePackets(); + absl::InlinedVector NeuterHandshakePackets(); // Returns true if |packet_number| has retransmittable frames. This will // return false if all frames of this packet are either non-retransmittable or @@ -113,10 +113,10 @@ class QUIC_EXPORT_PRIVATE QuicUnackedPacketMap { QuicPacketNumber GetLeastUnacked() const; using const_iterator = - QuicCircularDeque::const_iterator; + quiche::QuicheCircularDeque::const_iterator; using const_reverse_iterator = - QuicCircularDeque::const_reverse_iterator; - using iterator = QuicCircularDeque::iterator; + quiche::QuicheCircularDeque::const_reverse_iterator; + using iterator = quiche::QuicheCircularDeque::iterator; const_iterator begin() const { return unacked_packets_.begin(); } const_iterator end() const { return unacked_packets_.end(); } @@ -300,7 +300,7 @@ class QUIC_EXPORT_PRIVATE QuicUnackedPacketMap { // If the old packet is acked before the new packet, then the old entry will // be removed from the map and the new entry's retransmittable frames will be // set to nullptr. - QuicCircularDeque unacked_packets_; + quiche::QuicheCircularDeque unacked_packets_; // The packet at the 0th index of unacked_packets_. QuicPacketNumber least_unacked_; diff --git a/gquiche/quic/core/quic_utils.cc b/gquiche/quic/core/quic_utils.cc index 3c074b8c..385ed54a 100644 --- a/gquiche/quic/core/quic_utils.cc +++ b/gquiche/quic/core/quic_utils.cc @@ -14,6 +14,7 @@ #include "absl/base/optimization.h" #include "absl/numeric/int128.h" #include "absl/strings/string_view.h" +#include "openssl/sha.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_constants.h" #include "gquiche/quic/core/quic_types.h" @@ -21,8 +22,9 @@ #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/quic/platform/api/quic_prefetch.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" #include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/common/platform/api/quiche_prefetch.h" #include "gquiche/common/quiche_endian.h" namespace quic { @@ -266,9 +268,9 @@ void QuicUtils::CopyToBuffer(const struct iovec* iov, char* next_base = static_cast(iov[iovnum + 1].iov_base); // Prefetch 2 cachelines worth of data to get the prefetcher started; leave // it to the hardware prefetcher after that. - QuicPrefetchT0(next_base); + quiche::QuichePrefetchT0(next_base); if (iov[iovnum + 1].iov_len >= 64) { - QuicPrefetchT0(next_base + ABSL_CACHELINE_SIZE); + quiche::QuichePrefetchT0(next_base + ABSL_CACHELINE_SIZE); } } @@ -308,7 +310,6 @@ bool QuicUtils::IsRetransmittableFrame(QuicFrameType type) { case MTU_DISCOVERY_FRAME: case PATH_CHALLENGE_FRAME: case PATH_RESPONSE_FRAME: - case NEW_CONNECTION_ID_FRAME: return false; default: return true; @@ -694,6 +695,32 @@ bool QuicUtils::IsProbingFrame(QuicFrameType type) { } } +// static +bool QuicUtils::IsAckElicitingFrame(QuicFrameType type) { + switch (type) { + case PADDING_FRAME: + case STOP_WAITING_FRAME: + case ACK_FRAME: + case CONNECTION_CLOSE_FRAME: + return false; + default: + return true; + } +} + +// static +bool QuicUtils::AreStatelessResetTokensEqual( + const StatelessResetToken& token1, + const StatelessResetToken& token2) { + char byte = 0; + for (size_t i = 0; i < kStatelessResetTokenLength; i++) { + // This avoids compiler optimizations that could make us stop comparing + // after we find a byte that doesn't match. + byte |= (token1[i] ^ token2[i]); + } + return byte == 0; +} + bool IsValidWebTransportSessionId(WebTransportSessionId id, ParsedQuicVersion version) { QUICHE_DCHECK(version.UsesHttp3()); @@ -702,5 +729,21 @@ bool IsValidWebTransportSessionId(WebTransportSessionId id, QuicUtils::IsClientInitiatedStreamId(version.transport_version, id); } +QuicByteCount MemSliceSpanTotalSize(absl::Span span) { + QuicByteCount total = 0; + for (const QuicMemSlice& slice : span) { + total += slice.length(); + } + return total; +} + +std::string RawSha256(absl::string_view input) { + std::string raw_hash; + raw_hash.resize(SHA256_DIGEST_LENGTH); + SHA256(reinterpret_cast(input.data()), input.size(), + reinterpret_cast(&raw_hash[0])); + return raw_hash; +} + #undef RETURN_STRING_LITERAL // undef for jumbo builds } // namespace quic diff --git a/gquiche/quic/core/quic_utils.h b/gquiche/quic/core/quic_utils.h index cbb0b981..5faa949d 100644 --- a/gquiche/quic/core/quic_utils.h +++ b/gquiche/quic/core/quic_utils.h @@ -13,6 +13,7 @@ #include "absl/numeric/int128.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "gquiche/quic/core/crypto/quic_random.h" #include "gquiche/quic/core/frames/quic_frame.h" #include "gquiche/quic/core/quic_connection_id.h" @@ -21,6 +22,7 @@ #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_iovec.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" #include "gquiche/quic/platform/api/quic_socket_address.h" namespace quic { @@ -237,6 +239,14 @@ class QUIC_EXPORT_PRIVATE QuicUtils { // Return true if this frame is an IETF probing frame. static bool IsProbingFrame(QuicFrameType type); + + // Return true if the two stateless reset tokens are equal. Performs the + // comparison in constant time. + static bool AreStatelessResetTokensEqual(const StatelessResetToken& token1, + const StatelessResetToken& token2); + + // Return ture if this frame is an ack-eliciting frame. + static bool IsAckElicitingFrame(QuicFrameType type); }; // Returns true if the specific ID is a valid WebTransport session ID that our @@ -244,6 +254,11 @@ class QUIC_EXPORT_PRIVATE QuicUtils { bool IsValidWebTransportSessionId(WebTransportSessionId id, ParsedQuicVersion transport_version); +QuicByteCount MemSliceSpanTotalSize(absl::Span span); + +// Computes a SHA-256 hash and returns the raw bytes of the hash. +QUIC_EXPORT_PRIVATE std::string RawSha256(absl::string_view input); + template class QUIC_EXPORT_PRIVATE BitMask { public: diff --git a/gquiche/quic/core/quic_utils_test.cc b/gquiche/quic/core/quic_utils_test.cc index 210167d4..293102ac 100644 --- a/gquiche/quic/core/quic_utils_test.cc +++ b/gquiche/quic/core/quic_utils_test.cc @@ -320,6 +320,8 @@ TEST_F(QuicUtilsTest, StatelessResetToken) { QuicUtils::GenerateStatelessResetToken(connection_id2); EXPECT_EQ(token1a, token1b); EXPECT_NE(token1a, token2); + EXPECT_TRUE(QuicUtils::AreStatelessResetTokensEqual(token1a, token1b)); + EXPECT_FALSE(QuicUtils::AreStatelessResetTokensEqual(token1a, token2)); } enum class TestEnumClassBit : uint8_t { diff --git a/gquiche/quic/core/quic_version_manager.cc b/gquiche/quic/core/quic_version_manager.cc index 6194712e..e2dc5527 100644 --- a/gquiche/quic/core/quic_version_manager.cc +++ b/gquiche/quic/core/quic_version_manager.cc @@ -15,18 +15,7 @@ namespace quic { QuicVersionManager::QuicVersionManager( ParsedQuicVersionVector supported_versions) - : enable_version_rfcv1_(GetQuicReloadableFlag(quic_enable_version_rfcv1)), - disable_version_draft_29_( - GetQuicReloadableFlag(quic_disable_version_draft_29)), - disable_version_t051_(GetQuicReloadableFlag(quic_disable_version_t051)), - disable_version_q050_(GetQuicReloadableFlag(quic_disable_version_q050)), - disable_version_q046_(GetQuicReloadableFlag(quic_disable_version_q046)), - disable_version_q043_(GetQuicReloadableFlag(quic_disable_version_q043)), - allowed_supported_versions_(std::move(supported_versions)) { - static_assert(SupportedVersions().size() == 6u, - "Supported versions out of sync"); - RefilterSupportedVersions(); -} + : allowed_supported_versions_(std::move(supported_versions)) {} QuicVersionManager::~QuicVersionManager() {} @@ -35,6 +24,12 @@ const ParsedQuicVersionVector& QuicVersionManager::GetSupportedVersions() { return filtered_supported_versions_; } +const ParsedQuicVersionVector& +QuicVersionManager::GetSupportedVersionsWithOnlyHttp3() { + MaybeRefilterSupportedVersions(); + return filtered_supported_versions_with_http3_; +} + const ParsedQuicVersionVector& QuicVersionManager::GetSupportedVersionsWithQuicCrypto() { MaybeRefilterSupportedVersions(); @@ -47,24 +42,21 @@ const std::vector& QuicVersionManager::GetSupportedAlpns() { } void QuicVersionManager::MaybeRefilterSupportedVersions() { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); - if (enable_version_rfcv1_ != - GetQuicReloadableFlag(quic_enable_version_rfcv1) || + if (disable_version_rfcv1_ != + GetQuicReloadableFlag(quic_disable_version_rfcv1) || disable_version_draft_29_ != GetQuicReloadableFlag(quic_disable_version_draft_29) || - disable_version_t051_ != - GetQuicReloadableFlag(quic_disable_version_t051) || disable_version_q050_ != GetQuicReloadableFlag(quic_disable_version_q050) || disable_version_q046_ != GetQuicReloadableFlag(quic_disable_version_q046) || disable_version_q043_ != GetQuicReloadableFlag(quic_disable_version_q043)) { - enable_version_rfcv1_ = GetQuicReloadableFlag(quic_enable_version_rfcv1); + disable_version_rfcv1_ = GetQuicReloadableFlag(quic_disable_version_rfcv1); disable_version_draft_29_ = GetQuicReloadableFlag(quic_disable_version_draft_29); - disable_version_t051_ = GetQuicReloadableFlag(quic_disable_version_t051); disable_version_q050_ = GetQuicReloadableFlag(quic_disable_version_q050); disable_version_q046_ = GetQuicReloadableFlag(quic_disable_version_q046); disable_version_q043_ = GetQuicReloadableFlag(quic_disable_version_q043); @@ -76,6 +68,7 @@ void QuicVersionManager::MaybeRefilterSupportedVersions() { void QuicVersionManager::RefilterSupportedVersions() { filtered_supported_versions_ = FilterSupportedVersions(allowed_supported_versions_); + filtered_supported_versions_with_http3_.clear(); filtered_supported_versions_with_quic_crypto_.clear(); filtered_transport_versions_.clear(); filtered_supported_alpns_.clear(); @@ -86,6 +79,9 @@ void QuicVersionManager::RefilterSupportedVersions() { transport_version) == filtered_transport_versions_.end()) { filtered_transport_versions_.push_back(transport_version); } + if (version.UsesHttp3()) { + filtered_supported_versions_with_http3_.push_back(version); + } if (version.handshake_protocol == PROTOCOL_QUIC_CRYPTO) { filtered_supported_versions_with_quic_crypto_.push_back(version); } diff --git a/gquiche/quic/core/quic_version_manager.h b/gquiche/quic/core/quic_version_manager.h index 9373fbc9..6c8169b3 100644 --- a/gquiche/quic/core/quic_version_manager.h +++ b/gquiche/quic/core/quic_version_manager.h @@ -22,6 +22,9 @@ class QUIC_EXPORT_PRIVATE QuicVersionManager { // as the versions passed to the constructor. const ParsedQuicVersionVector& GetSupportedVersions(); + // Returns currently supported versions using HTTP/3. + const ParsedQuicVersionVector& GetSupportedVersionsWithOnlyHttp3(); + // Returns currently supported versions using QUIC crypto. const ParsedQuicVersionVector& GetSupportedVersionsWithQuicCrypto(); @@ -33,41 +36,52 @@ class QUIC_EXPORT_PRIVATE QuicVersionManager { // If the value of any reloadable flag is different from the cached value, // re-filter |filtered_supported_versions_| and update the cached flag values. // Otherwise, does nothing. + // TODO(dschinazi): Make private when deprecating + // FLAGS_gfe2_restart_flag_quic_disable_old_alt_svc_format. void MaybeRefilterSupportedVersions(); // Refilters filtered_supported_versions_. virtual void RefilterSupportedVersions(); + // RefilterSupportedVersions() must be called before calling this method. + // TODO(dschinazi): Remove when deprecating + // FLAGS_gfe2_restart_flag_quic_disable_old_alt_svc_format. const QuicTransportVersionVector& filtered_transport_versions() const { return filtered_transport_versions_; } - // Mechanism for subclasses to add custom ALPNs to the supported list. - // Should be called in constructor and RefilterSupportedVersions. + // Subclasses may add custom ALPNs to the supported list by overriding + // RefilterSupportedVersions() to first call + // QuicVersionManager::RefilterSupportedVersions() then AddCustomAlpn(). + // Must not be called elsewhere. void AddCustomAlpn(const std::string& alpn); - bool disable_version_q050() const { return disable_version_q050_; } - private: // Cached value of reloadable flags. - // quic_enable_version_rfcv1 flag - bool enable_version_rfcv1_; + // quic_disable_version_rfcv1 flag + bool disable_version_rfcv1_ = true; // quic_disable_version_draft_29 flag - bool disable_version_draft_29_; - // quic_disable_version_t051 flag - bool disable_version_t051_; + bool disable_version_draft_29_ = true; // quic_disable_version_q050 flag - bool disable_version_q050_; + bool disable_version_q050_ = true; // quic_disable_version_q046 flag - bool disable_version_q046_; + bool disable_version_q046_ = true; // quic_disable_version_q043 flag - bool disable_version_q043_; + bool disable_version_q043_ = true; // The list of versions that may be supported. - ParsedQuicVersionVector allowed_supported_versions_; + const ParsedQuicVersionVector allowed_supported_versions_; + + // The following vectors are calculated from reloadable flags by + // RefilterSupportedVersions(). It is performed lazily when first needed, and + // after that, since the calculation is relatively expensive, only if the flag + // values change. + // This vector contains QUIC versions which are currently supported based on // flags. ParsedQuicVersionVector filtered_supported_versions_; + // Currently supported versions using HTTP/3. + ParsedQuicVersionVector filtered_supported_versions_with_http3_; // Currently supported versions using QUIC crypto. ParsedQuicVersionVector filtered_supported_versions_with_quic_crypto_; // This vector contains the transport versions from diff --git a/gquiche/quic/core/quic_version_manager_test.cc b/gquiche/quic/core/quic_version_manager_test.cc index 482d679c..bea4ce2b 100644 --- a/gquiche/quic/core/quic_version_manager_test.cc +++ b/gquiche/quic/core/quic_version_manager_test.cc @@ -18,14 +18,13 @@ namespace { class QuicVersionManagerTest : public QuicTest {}; TEST_F(QuicVersionManagerTest, QuicVersionManager) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); for (const ParsedQuicVersion& version : AllSupportedVersions()) { QuicEnableVersion(version); } QuicDisableVersion(ParsedQuicVersion::RFCv1()); QuicDisableVersion(ParsedQuicVersion::Draft29()); - QuicDisableVersion(ParsedQuicVersion::T051()); QuicVersionManager manager(AllSupportedVersions()); ParsedQuicVersionVector expected_parsed_versions; @@ -37,6 +36,7 @@ TEST_F(QuicVersionManagerTest, QuicVersionManager) { EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), manager.GetSupportedVersions()); + EXPECT_TRUE(manager.GetSupportedVersionsWithOnlyHttp3().empty()); EXPECT_EQ(CurrentSupportedVersionsWithQuicCrypto(), manager.GetSupportedVersionsWithQuicCrypto()); EXPECT_THAT(manager.GetSupportedAlpns(), @@ -44,31 +44,37 @@ TEST_F(QuicVersionManagerTest, QuicVersionManager) { int offset = 0; QuicEnableVersion(ParsedQuicVersion::Draft29()); - expected_parsed_versions.insert(expected_parsed_versions.begin() + offset, + expected_parsed_versions.insert(expected_parsed_versions.begin(), ParsedQuicVersion::Draft29()); EXPECT_EQ(expected_parsed_versions, manager.GetSupportedVersions()); EXPECT_EQ(expected_parsed_versions.size() - 1 - offset, manager.GetSupportedVersionsWithQuicCrypto().size()); EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), manager.GetSupportedVersions()); + EXPECT_EQ(1u, manager.GetSupportedVersionsWithOnlyHttp3().size()); + EXPECT_EQ(CurrentSupportedHttp3Versions(), + manager.GetSupportedVersionsWithOnlyHttp3()); EXPECT_EQ(CurrentSupportedVersionsWithQuicCrypto(), manager.GetSupportedVersionsWithQuicCrypto()); EXPECT_THAT(manager.GetSupportedAlpns(), ElementsAre("h3-29", "h3-Q050", "h3-Q046", "h3-Q043")); offset++; - QuicEnableVersion(ParsedQuicVersion::T051()); - expected_parsed_versions.insert(expected_parsed_versions.begin() + offset, - ParsedQuicVersion::T051()); + QuicEnableVersion(ParsedQuicVersion::RFCv1()); + expected_parsed_versions.insert(expected_parsed_versions.begin(), + ParsedQuicVersion::RFCv1()); EXPECT_EQ(expected_parsed_versions, manager.GetSupportedVersions()); EXPECT_EQ(expected_parsed_versions.size() - 1 - offset, manager.GetSupportedVersionsWithQuicCrypto().size()); EXPECT_EQ(FilterSupportedVersions(AllSupportedVersions()), manager.GetSupportedVersions()); + EXPECT_EQ(2u, manager.GetSupportedVersionsWithOnlyHttp3().size()); + EXPECT_EQ(CurrentSupportedHttp3Versions(), + manager.GetSupportedVersionsWithOnlyHttp3()); EXPECT_EQ(CurrentSupportedVersionsWithQuicCrypto(), manager.GetSupportedVersionsWithQuicCrypto()); EXPECT_THAT(manager.GetSupportedAlpns(), - ElementsAre("h3-29", "h3-T051", "h3-Q050", "h3-Q046", "h3-Q043")); + ElementsAre("h3", "h3-29", "h3-Q050", "h3-Q046", "h3-Q043")); } } // namespace diff --git a/gquiche/quic/core/quic_versions.cc b/gquiche/quic/core/quic_versions.cc index 2f366342..f79ae1bb 100644 --- a/gquiche/quic/core/quic_versions.cc +++ b/gquiche/quic/core/quic_versions.cc @@ -17,8 +17,8 @@ #include "gquiche/quic/platform/api/quic_flag_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/quiche_endian.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace { @@ -42,16 +42,14 @@ QuicVersionLabel CreateRandomVersionLabelForNegotiation() { } void SetVersionFlag(const ParsedQuicVersion& version, bool should_enable) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); const bool enable = should_enable; const bool disable = !should_enable; if (version == ParsedQuicVersion::RFCv1()) { - SetQuicReloadableFlag(quic_enable_version_rfcv1, enable); + SetQuicReloadableFlag(quic_disable_version_rfcv1, disable); } else if (version == ParsedQuicVersion::Draft29()) { SetQuicReloadableFlag(quic_disable_version_draft_29, disable); - } else if (version == ParsedQuicVersion::T051()) { - SetQuicReloadableFlag(quic_disable_version_t051, disable); } else if (version == ParsedQuicVersion::Q050()) { SetQuicReloadableFlag(quic_disable_version_q050, disable); } else if (version == ParsedQuicVersion::Q046()) { @@ -215,14 +213,12 @@ std::ostream& operator<<(std::ostream& os, } QuicVersionLabel CreateQuicVersionLabel(ParsedQuicVersion parsed_version) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); if (parsed_version == ParsedQuicVersion::RFCv1()) { return MakeVersionLabel(0x00, 0x00, 0x00, 0x01); } else if (parsed_version == ParsedQuicVersion::Draft29()) { return MakeVersionLabel(0xff, 0x00, 0x00, 29); - } else if (parsed_version == ParsedQuicVersion::T051()) { - return MakeVersionLabel('T', '0', '5', '1'); } else if (parsed_version == ParsedQuicVersion::Q050()) { return MakeVersionLabel('Q', '0', '5', '0'); } else if (parsed_version == ParsedQuicVersion::Q046()) { @@ -297,6 +293,18 @@ ParsedQuicVersionVector CurrentSupportedVersionsWithTls() { return versions; } +ParsedQuicVersionVector CurrentSupportedHttp3Versions() { + ParsedQuicVersionVector versions; + for (const ParsedQuicVersion& version : CurrentSupportedVersions()) { + if (version.UsesHttp3()) { + versions.push_back(version); + } + } + QUIC_BUG_IF(no_version_uses_http3, versions.empty()) + << "No version speaking Http3 found."; + return versions; +} + ParsedQuicVersion ParseQuicVersionLabel(QuicVersionLabel version_label) { for (const ParsedQuicVersion& version : AllSupportedVersions()) { if (version_label == CreateQuicVersionLabel(version)) { @@ -309,6 +317,18 @@ ParsedQuicVersion ParseQuicVersionLabel(QuicVersionLabel version_label) { return UnsupportedQuicVersion(); } +ParsedQuicVersionVector ParseQuicVersionLabelVector( + const QuicVersionLabelVector& version_labels) { + ParsedQuicVersionVector parsed_versions; + for (const QuicVersionLabel& version_label : version_labels) { + ParsedQuicVersion parsed_version = ParseQuicVersionLabel(version_label); + if (parsed_version.IsKnown()) { + parsed_versions.push_back(parsed_version); + } + } + return parsed_versions; +} + ParsedQuicVersion ParseQuicVersionString(absl::string_view version_string) { if (version_string.empty()) { return UnsupportedQuicVersion(); @@ -395,17 +415,13 @@ ParsedQuicVersionVector FilterSupportedVersions( filtered_versions.reserve(versions.size()); for (const ParsedQuicVersion& version : versions) { if (version == ParsedQuicVersion::RFCv1()) { - if (GetQuicReloadableFlag(quic_enable_version_rfcv1)) { + if (!GetQuicReloadableFlag(quic_disable_version_rfcv1)) { filtered_versions.push_back(version); } } else if (version == ParsedQuicVersion::Draft29()) { if (!GetQuicReloadableFlag(quic_disable_version_draft_29)) { filtered_versions.push_back(version); } - } else if (version == ParsedQuicVersion::T051()) { - if (!GetQuicReloadableFlag(quic_disable_version_t051)) { - filtered_versions.push_back(version); - } } else if (version == ParsedQuicVersion::Q050()) { if (!GetQuicReloadableFlag(quic_disable_version_q050)) { filtered_versions.push_back(version); @@ -472,7 +488,6 @@ std::string QuicVersionToString(QuicTransportVersion transport_version) { RETURN_STRING_LITERAL(QUIC_VERSION_43); RETURN_STRING_LITERAL(QUIC_VERSION_46); RETURN_STRING_LITERAL(QUIC_VERSION_50); - RETURN_STRING_LITERAL(QUIC_VERSION_51); RETURN_STRING_LITERAL(QUIC_VERSION_IETF_DRAFT_29); RETURN_STRING_LITERAL(QUIC_VERSION_IETF_RFC_V1); RETURN_STRING_LITERAL(QUIC_VERSION_UNSUPPORTED); @@ -493,7 +508,7 @@ std::string HandshakeProtocolToString(HandshakeProtocol handshake_protocol) { } std::string ParsedQuicVersionToString(ParsedQuicVersion version) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); if (version == UnsupportedQuicVersion()) { return "0"; @@ -599,7 +614,7 @@ std::string AlpnForVersion(ParsedQuicVersion parsed_version) { void QuicVersionInitializeSupportForIetfDraft() { // Enable necessary flags. - SetQuicReloadableFlag(quic_fix_key_update_on_first_packet, true); + SetQuicReloadableFlag(quic_version_information, true); } void QuicEnableVersion(const ParsedQuicVersion& version) { diff --git a/gquiche/quic/core/quic_versions.h b/gquiche/quic/core/quic_versions.h index b9076aab..702d3e52 100644 --- a/gquiche/quic/core/quic_versions.h +++ b/gquiche/quic/core/quic_versions.h @@ -118,14 +118,14 @@ enum QuicTransportVersion { // Version 49 added client connection IDs, long header lengths, and the IETF // header format from draft-ietf-quic-invariants-06 QUIC_VERSION_50 = 50, // Header protection and initial obfuscators. - QUIC_VERSION_51 = 51, // draft-29 features but with GoogleQUIC frames. + // Number 51 was T051 which used draft-29 features but with GoogleQUIC frames. // Number 70 used to represent draft-ietf-quic-transport-25. // Number 71 used to represent draft-ietf-quic-transport-27. // Number 72 used to represent draft-ietf-quic-transport-28. QUIC_VERSION_IETF_DRAFT_29 = 73, // draft-ietf-quic-transport-29. - QUIC_VERSION_IETF_RFC_V1 = 80, // Not-yet-published RFC. + QUIC_VERSION_IETF_RFC_V1 = 80, // RFC 9000. // Version 99 was a dumping ground for IETF QUIC changes which were not yet - // yet ready for production between 2018-02 and 2020-02. + // ready for production between 2018-02 and 2020-02. // QUIC_VERSION_RESERVED_FOR_NEGOTIATION is sent over the wire as ?a?a?a?a // which is part of a range reserved by the IETF for version negotiation @@ -173,7 +173,6 @@ QUIC_EXPORT_PRIVATE constexpr bool ParsedQuicVersionIsValid( constexpr QuicTransportVersion valid_transport_versions[] = { QUIC_VERSION_IETF_RFC_V1, QUIC_VERSION_IETF_DRAFT_29, - QUIC_VERSION_51, QUIC_VERSION_50, QUIC_VERSION_46, QUIC_VERSION_43, @@ -195,7 +194,6 @@ QUIC_EXPORT_PRIVATE constexpr bool ParsedQuicVersionIsValid( case PROTOCOL_QUIC_CRYPTO: return transport_version != QUIC_VERSION_UNSUPPORTED && transport_version != QUIC_VERSION_RESERVED_FOR_NEGOTIATION && - transport_version != QUIC_VERSION_51 && transport_version != QUIC_VERSION_IETF_DRAFT_29 && transport_version != QUIC_VERSION_IETF_RFC_V1; case PROTOCOL_TLS1_3: @@ -255,10 +253,6 @@ struct QUIC_EXPORT_PRIVATE ParsedQuicVersion { return ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_IETF_DRAFT_29); } - static constexpr ParsedQuicVersion T051() { - return ParsedQuicVersion(PROTOCOL_TLS1_3, QUIC_VERSION_51); - } - static constexpr ParsedQuicVersion Q050() { return ParsedQuicVersion(PROTOCOL_QUIC_CRYPTO, QUIC_VERSION_50); } @@ -400,11 +394,11 @@ constexpr std::array SupportedHandshakeProtocols() { return {PROTOCOL_TLS1_3, PROTOCOL_QUIC_CRYPTO}; } -constexpr std::array SupportedVersions() { +constexpr std::array SupportedVersions() { return { ParsedQuicVersion::RFCv1(), ParsedQuicVersion::Draft29(), - ParsedQuicVersion::T051(), ParsedQuicVersion::Q050(), - ParsedQuicVersion::Q046(), ParsedQuicVersion::Q043(), + ParsedQuicVersion::Q050(), ParsedQuicVersion::Q046(), + ParsedQuicVersion::Q043(), }; } @@ -446,6 +440,10 @@ QUIC_EXPORT_PRIVATE ParsedQuicVersionVector AllSupportedVersionsWithTls(); // PROTOCOL_TLS1_3. QUIC_EXPORT_PRIVATE ParsedQuicVersionVector CurrentSupportedVersionsWithTls(); +// Returns a subset of CurrentSupportedVersions() using HTTP/3 at the HTTP +// layer. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector CurrentSupportedHttp3Versions(); + // Returns QUIC version of |index| in result of |versions|. Returns // UnsupportedQuicVersion() if |index| is out of bounds. QUIC_EXPORT_PRIVATE ParsedQuicVersionVector @@ -458,6 +456,11 @@ ParsedVersionOfIndex(const ParsedQuicVersionVector& versions, int index); QUIC_EXPORT_PRIVATE ParsedQuicVersion ParseQuicVersionLabel(QuicVersionLabel version_label); +// Helper function that translates from a QuicVersionLabelVector to a +// ParsedQuicVersionVector. +QUIC_EXPORT_PRIVATE ParsedQuicVersionVector +ParseQuicVersionLabelVector(const QuicVersionLabelVector& version_labels); + // Parses a QUIC version string such as "Q043" or "T051". Also supports parsing // ALPN such as "h3-29" or "h3-Q050". For PROTOCOL_QUIC_CRYPTO versions, also // supports parsing numbers such as "46". @@ -562,7 +565,7 @@ QUIC_EXPORT_PRIVATE constexpr bool VersionSupportsMessageFrames( // * GOAWAY is moved to HTTP layer. QUIC_EXPORT_PRIVATE constexpr bool VersionUsesHttp3( QuicTransportVersion transport_version) { - return transport_version > QUIC_VERSION_51; + return transport_version >= QUIC_VERSION_IETF_DRAFT_29; } // Returns whether the transport_version supports the variable length integer diff --git a/gquiche/quic/core/quic_versions_test.cc b/gquiche/quic/core/quic_versions_test.cc index 13eb905c..04c62549 100644 --- a/gquiche/quic/core/quic_versions_test.cc +++ b/gquiche/quic/core/quic_versions_test.cc @@ -117,10 +117,16 @@ TEST_F(QuicVersionsTest, ParseQuicVersionLabel) { ParseQuicVersionLabel(MakeVersionLabel('Q', '0', '4', '6'))); EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionLabel(MakeVersionLabel('Q', '0', '5', '0'))); - EXPECT_EQ(ParsedQuicVersion::T051(), - ParseQuicVersionLabel(MakeVersionLabel('T', '0', '5', '1'))); EXPECT_EQ(ParsedQuicVersion::Draft29(), ParseQuicVersionLabel(MakeVersionLabel(0xff, 0x00, 0x00, 0x1d))); + EXPECT_EQ(ParsedQuicVersion::RFCv1(), + ParseQuicVersionLabel(MakeVersionLabel(0x00, 0x00, 0x00, 0x01))); + EXPECT_EQ((ParsedQuicVersionVector{ParsedQuicVersion::RFCv1(), + ParsedQuicVersion::Draft29()}), + ParseQuicVersionLabelVector(QuicVersionLabelVector{ + MakeVersionLabel(0x00, 0x00, 0x00, 0x01), + MakeVersionLabel(0xaa, 0xaa, 0xaa, 0xaa), + MakeVersionLabel(0xff, 0x00, 0x00, 0x1d)})); } TEST_F(QuicVersionsTest, ParseQuicVersionString) { @@ -132,8 +138,6 @@ TEST_F(QuicVersionsTest, ParseQuicVersionString) { EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionString("Q050")); EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionString("50")); EXPECT_EQ(ParsedQuicVersion::Q050(), ParseQuicVersionString("h3-Q050")); - EXPECT_EQ(ParsedQuicVersion::T051(), ParseQuicVersionString("T051")); - EXPECT_EQ(ParsedQuicVersion::T051(), ParseQuicVersionString("h3-T051")); EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("")); EXPECT_EQ(UnsupportedQuicVersion(), ParseQuicVersionString("Q 46")); @@ -154,7 +158,6 @@ TEST_F(QuicVersionsTest, ParseQuicVersionString) { TEST_F(QuicVersionsTest, ParseQuicVersionVectorString) { ParsedQuicVersion version_q046 = ParsedQuicVersion::Q046(); ParsedQuicVersion version_q050 = ParsedQuicVersion::Q050(); - ParsedQuicVersion version_t051 = ParsedQuicVersion::T051(); ParsedQuicVersion version_draft_29 = ParsedQuicVersion::Draft29(); EXPECT_THAT(ParseQuicVersionVectorString(""), IsEmpty()); @@ -163,36 +166,22 @@ TEST_F(QuicVersionsTest, ParseQuicVersionVectorString) { ElementsAre(version_q050)); EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050"), ElementsAre(version_q050)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051"), - ElementsAre(version_t051)); - - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051, h3-29"), - ElementsAre(version_t051, version_draft_29)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-T051,h3-29"), - ElementsAre(version_draft_29, version_t051)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-T051, h3-29"), - ElementsAre(version_draft_29, version_t051)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051,h3-29"), - ElementsAre(version_t051, version_draft_29)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-T051"), - ElementsAre(version_draft_29, version_t051)); - - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051,50"), - ElementsAre(version_t051, version_q050)); - - EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, h3-T051"), - ElementsAre(version_q050, version_t051)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051, h3-Q050"), - ElementsAre(version_t051, version_q050)); - EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50,h3-T051"), - ElementsAre(version_q050, version_t051)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051,QUIC_VERSION_50"), - ElementsAre(version_t051, version_q050)); - EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, h3-T051"), - ElementsAre(version_q050, version_t051)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051, QUIC_VERSION_50"), - ElementsAre(version_t051, version_q050)); - + EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, h3-29"), + ElementsAre(version_q050, version_draft_29)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-Q050,h3-29"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29,h3-Q050, h3-29"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29, h3-Q050"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50,h3-29"), + ElementsAre(version_q050, version_draft_29)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29,QUIC_VERSION_50"), + ElementsAre(version_draft_29, version_q050)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, h3-29"), + ElementsAre(version_q050, version_draft_29)); + EXPECT_THAT(ParseQuicVersionVectorString("h3-29, QUIC_VERSION_50"), + ElementsAre(version_draft_29, version_q050)); EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50,QUIC_VERSION_46"), ElementsAre(version_q050, version_q046)); EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_46,QUIC_VERSION_50"), @@ -203,15 +192,13 @@ TEST_F(QuicVersionsTest, ParseQuicVersionVectorString) { ElementsAre(version_q050)); EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, h3-Q050"), ElementsAre(version_q050)); - EXPECT_THAT(ParseQuicVersionVectorString("h3-T051, h3-T051"), - ElementsAre(version_t051)); EXPECT_THAT(ParseQuicVersionVectorString("h3-Q050, QUIC_VERSION_50"), ElementsAre(version_q050)); EXPECT_THAT(ParseQuicVersionVectorString( "QUIC_VERSION_50, h3-Q050, QUIC_VERSION_50, h3-Q050"), ElementsAre(version_q050)); - EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, h3-T051, h3-Q050"), - ElementsAre(version_q050, version_t051)); + EXPECT_THAT(ParseQuicVersionVectorString("QUIC_VERSION_50, h3-29, h3-Q050"), + ElementsAre(version_q050, version_draft_29)); EXPECT_THAT(ParseQuicVersionVectorString("99"), IsEmpty()); EXPECT_THAT(ParseQuicVersionVectorString("70"), IsEmpty()); @@ -227,10 +214,10 @@ TEST_F(QuicVersionsTest, CreateQuicVersionLabel) { CreateQuicVersionLabel(ParsedQuicVersion::Q046())); EXPECT_EQ(MakeVersionLabel('Q', '0', '5', '0'), CreateQuicVersionLabel(ParsedQuicVersion::Q050())); - - // Test a TLS version: - EXPECT_EQ(MakeVersionLabel('T', '0', '5', '1'), - CreateQuicVersionLabel(ParsedQuicVersion::T051())); + EXPECT_EQ(MakeVersionLabel(0xff, 0x00, 0x00, 0x1d), + CreateQuicVersionLabel(ParsedQuicVersion::Draft29())); + EXPECT_EQ(MakeVersionLabel(0x00, 0x00, 0x00, 0x01), + CreateQuicVersionLabel(ParsedQuicVersion::RFCv1())); // Make sure the negotiation reserved version is in the IETF reserved space. EXPECT_EQ( @@ -303,8 +290,8 @@ TEST_F(QuicVersionsTest, ParsedQuicVersionToString) { EXPECT_EQ("Q043", ParsedQuicVersionToString(ParsedQuicVersion::Q043())); EXPECT_EQ("Q046", ParsedQuicVersionToString(ParsedQuicVersion::Q046())); EXPECT_EQ("Q050", ParsedQuicVersionToString(ParsedQuicVersion::Q050())); - EXPECT_EQ("T051", ParsedQuicVersionToString(ParsedQuicVersion::T051())); EXPECT_EQ("draft29", ParsedQuicVersionToString(ParsedQuicVersion::Draft29())); + EXPECT_EQ("RFCv1", ParsedQuicVersionToString(ParsedQuicVersion::RFCv1())); ParsedQuicVersionVector versions_vector = {ParsedQuicVersion::Q043()}; EXPECT_EQ("Q043", ParsedQuicVersionVectorToString(versions_vector)); @@ -371,23 +358,21 @@ TEST_F(QuicVersionsTest, LookUpParsedVersionByIndex) { // yet a typo was made in doing the #defines and it was caught // only in some test far removed from here... Better safe than sorry. TEST_F(QuicVersionsTest, CheckTransportVersionNumbersForTypos) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); EXPECT_EQ(QUIC_VERSION_43, 43); EXPECT_EQ(QUIC_VERSION_46, 46); EXPECT_EQ(QUIC_VERSION_50, 50); - EXPECT_EQ(QUIC_VERSION_51, 51); EXPECT_EQ(QUIC_VERSION_IETF_DRAFT_29, 73); EXPECT_EQ(QUIC_VERSION_IETF_RFC_V1, 80); } TEST_F(QuicVersionsTest, AlpnForVersion) { - static_assert(SupportedVersions().size() == 6u, + static_assert(SupportedVersions().size() == 5u, "Supported versions out of sync"); EXPECT_EQ("h3-Q043", AlpnForVersion(ParsedQuicVersion::Q043())); EXPECT_EQ("h3-Q046", AlpnForVersion(ParsedQuicVersion::Q046())); EXPECT_EQ("h3-Q050", AlpnForVersion(ParsedQuicVersion::Q050())); - EXPECT_EQ("h3-T051", AlpnForVersion(ParsedQuicVersion::T051())); EXPECT_EQ("h3-29", AlpnForVersion(ParsedQuicVersion::Draft29())); EXPECT_EQ("h3", AlpnForVersion(ParsedQuicVersion::RFCv1())); } @@ -445,6 +430,25 @@ TEST_F(QuicVersionsTest, SupportedVersionsAllDistinct) { } } +TEST_F(QuicVersionsTest, CurrentSupportedHttp3Versions) { + ParsedQuicVersionVector h3_versions = CurrentSupportedHttp3Versions(); + ParsedQuicVersionVector all_current_supported_versions = + CurrentSupportedVersions(); + for (auto& version : all_current_supported_versions) { + bool version_is_h3 = false; + for (auto& h3_version : h3_versions) { + if (version == h3_version) { + EXPECT_TRUE(version.UsesHttp3()); + version_is_h3 = true; + break; + } + } + if (!version_is_h3) { + EXPECT_FALSE(version.UsesHttp3()); + } + } +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/quic_write_blocked_list.h b/gquiche/quic/core/quic_write_blocked_list.h index 86ac6c64..d2e49e65 100644 --- a/gquiche/quic/core/quic_write_blocked_list.h +++ b/gquiche/quic/core/quic_write_blocked_list.h @@ -9,13 +9,13 @@ #include #include +#include "absl/container/inlined_vector.h" #include "gquiche/http2/core/priority_write_scheduler.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { @@ -102,7 +102,7 @@ class QUIC_EXPORT_PRIVATE QuicWriteBlockedList { }; // Optimized for the typical case of 2 static streams per session. - using StreamsVector = QuicInlinedVector; + using StreamsVector = absl::InlinedVector; StreamsVector::const_iterator begin() const { return streams_.cbegin(); } diff --git a/gquiche/quic/core/stream_delegate_interface.h b/gquiche/quic/core/stream_delegate_interface.h index ce2c58ef..21a87c09 100644 --- a/gquiche/quic/core/stream_delegate_interface.h +++ b/gquiche/quic/core/stream_delegate_interface.h @@ -28,18 +28,13 @@ class QUIC_EXPORT_PRIVATE StreamDelegateInterface { virtual void OnStreamError(QuicErrorCode error_code, QuicIetfTransportErrorCodes ietf_error, std::string error_details) = 0; - // Called when the stream needs to write data. If |level| is present, the data - // will be written at the specified |level|. The data will be written - // at specified transmission |type|. - // TODO(fayang): Change absl::optional to EncryptionLevel - // when deprecating quic_use_write_or_buffer_data_at_level. - virtual QuicConsumedData WritevData( - QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, - absl::optional level) = 0; + // Called when the stream needs to write data at specified |level| and + // transmission |type|. + virtual QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, + StreamSendingState state, + TransmissionType type, + EncryptionLevel level) = 0; // Called to write crypto data. virtual size_t SendCryptoData(EncryptionLevel level, size_t write_length, diff --git a/gquiche/quic/core/tls_chlo_extractor.cc b/gquiche/quic/core/tls_chlo_extractor.cc index ee074510..3271830d 100644 --- a/gquiche/quic/core/tls_chlo_extractor.cc +++ b/gquiche/quic/core/tls_chlo_extractor.cc @@ -17,10 +17,19 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { +namespace { +bool HasExtension(const SSL_CLIENT_HELLO* client_hello, uint16_t extension) { + const uint8_t* unused_extension_bytes; + size_t unused_extension_len; + return 1 == SSL_early_callback_ctx_extension_get(client_hello, extension, + &unused_extension_bytes, + &unused_extension_len); +} +} // namespace + TlsChloExtractor::TlsChloExtractor() : crypto_stream_sequencer_(this), state_(State::kInitial), @@ -281,6 +290,11 @@ void TlsChloExtractor::HandleParsedChlo(const SSL_CLIENT_HELLO* client_hello) { if (server_name) { server_name_ = std::string(server_name); } + + resumption_attempted_ = + HasExtension(client_hello, TLSEXT_TYPE_pre_shared_key); + early_data_attempted_ = HasExtension(client_hello, TLSEXT_TYPE_early_data); + const uint8_t* alpn_data; size_t alpn_len; int rv = SSL_early_callback_ctx_extension_get( diff --git a/gquiche/quic/core/tls_chlo_extractor.h b/gquiche/quic/core/tls_chlo_extractor.h index bbe4922c..43517bc0 100644 --- a/gquiche/quic/core/tls_chlo_extractor.h +++ b/gquiche/quic/core/tls_chlo_extractor.h @@ -45,6 +45,8 @@ class QUIC_NO_EXPORT TlsChloExtractor State state() const { return state_; } std::vector alpns() const { return alpns_; } std::string server_name() const { return server_name_; } + bool resumption_attempted() const { return resumption_attempted_; } + bool early_data_attempted() const { return early_data_attempted_; } // Converts |state| to a human-readable string suitable for logging. static std::string StateToString(State state); @@ -177,7 +179,7 @@ class QUIC_NO_EXPORT TlsChloExtractor void OnDataAvailable() override; void OnFinRead() override {} void AddBytesConsumed(QuicByteCount /*bytes*/) override {} - void Reset(QuicRstStreamErrorCode /*error*/) override {} + void ResetWithError(QuicResetStreamError /*error*/) override {} void OnUnrecoverableError(QuicErrorCode error, const std::string& details) override; void OnUnrecoverableError(QuicErrorCode error, @@ -246,6 +248,12 @@ class QUIC_NO_EXPORT TlsChloExtractor std::vector alpns_; // SNI parsed from the CHLO. std::string server_name_; + // Whether resumption is attempted from the CHLO, indicated by the + // 'pre_shared_key' TLS extension. + bool resumption_attempted_ = false; + // Whether early data is attempted from the CHLO, indicated by the + // 'early_data' TLS extension. + bool early_data_attempted_ = false; }; // Convenience method to facilitate logging TlsChloExtractor::State. diff --git a/gquiche/quic/core/tls_chlo_extractor_test.cc b/gquiche/quic/core/tls_chlo_extractor_test.cc index c1c02cde..35a19614 100644 --- a/gquiche/quic/core/tls_chlo_extractor_test.cc +++ b/gquiche/quic/core/tls_chlo_extractor_test.cc @@ -3,26 +3,87 @@ // found in the LICENSE file. #include "gquiche/quic/core/tls_chlo_extractor.h" + #include +#include "openssl/ssl.h" #include "gquiche/quic/core/http/quic_spdy_client_session.h" #include "gquiche/quic/core/quic_connection.h" #include "gquiche/quic/core/quic_packet_writer_wrapper.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/crypto_test_utils.h" #include "gquiche/quic/test_tools/first_flight.h" #include "gquiche/quic/test_tools/quic_test_utils.h" +#include "gquiche/quic/test_tools/simple_session_cache.h" namespace quic { namespace test { namespace { +using testing::_; +using testing::AnyNumber; + class TlsChloExtractorTest : public QuicTestWithParam { protected: - TlsChloExtractorTest() : version_(GetParam()) {} + TlsChloExtractorTest() : version_(GetParam()), server_id_(TestServerId()) {} void Initialize() { packets_ = GetFirstFlightOfPackets(version_, config_); } + void Initialize(std::unique_ptr crypto_config) { + packets_ = GetFirstFlightOfPackets(version_, config_, TestConnectionId(), + EmptyQuicConnectionId(), + std::move(crypto_config)); + } + + // Perform a full handshake in order to insert a SSL_SESSION into + // crypto_config->session_cache(), which can be used by a TLS resumption. + void PerformFullHandshake(QuicCryptoClientConfig* crypto_config) const { + ASSERT_NE(crypto_config->session_cache(), nullptr); + MockQuicConnectionHelper client_helper, server_helper; + MockAlarmFactory alarm_factory; + ParsedQuicVersionVector supported_versions = {version_}; + PacketSavingConnection* client_connection = + new PacketSavingConnection(&client_helper, &alarm_factory, + Perspective::IS_CLIENT, supported_versions); + // Advance the time, because timers do not like uninitialized times. + client_connection->AdvanceTime(QuicTime::Delta::FromSeconds(1)); + QuicClientPushPromiseIndex push_promise_index; + QuicSpdyClientSession client_session(config_, supported_versions, + client_connection, server_id_, + crypto_config, &push_promise_index); + client_session.Initialize(); + + std::unique_ptr server_crypto_config = + crypto_test_utils::CryptoServerConfigForTesting(); + QuicConfig server_config; + + EXPECT_CALL(*client_connection, SendCryptoData(_, _, _)).Times(AnyNumber()); + client_session.GetMutableCryptoStream()->CryptoConnect(); + + crypto_test_utils::HandshakeWithFakeServer( + &server_config, server_crypto_config.get(), &server_helper, + &alarm_factory, client_connection, + client_session.GetMutableCryptoStream(), + AlpnForVersion(client_connection->version())); + + // For some reason, the test client can not receive the server settings and + // the SSL_SESSION will not be inserted to client's session_cache. We create + // a dummy settings and call SetServerApplicationStateForResumption manually + // to ensure the SSL_SESSION is cached. + // TODO(wub): Fix crypto_test_utils::HandshakeWithFakeServer to make sure a + // SSL_SESSION is cached at the client, and remove the rest of the function. + SettingsFrame server_settings; + server_settings.values[SETTINGS_QPACK_MAX_TABLE_CAPACITY] = + kDefaultQpackMaxDynamicTableCapacity; + std::unique_ptr buffer; + uint64_t length = + HttpEncoder::SerializeSettingsFrame(server_settings, &buffer); + client_session.GetMutableCryptoStream() + ->SetServerApplicationStateForResumption( + std::make_unique(buffer.get(), + buffer.get() + length)); + } void IngestPackets() { for (const std::unique_ptr& packet : packets_) { @@ -30,16 +91,14 @@ class TlsChloExtractorTest : public QuicTestWithParam { QuicSocketAddress(TestPeerIPAddress(), kTestPort), QuicSocketAddress(TestPeerIPAddress(), kTestPort), *packet); std::string detailed_error; - bool retry_token_present; - absl::string_view retry_token; + absl::optional retry_token; const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( *packet, /*expected_destination_connection_id_length=*/0, &packet_info.form, &packet_info.long_packet_type, &packet_info.version_flag, &packet_info.use_length_prefix, &packet_info.version_label, &packet_info.version, &packet_info.destination_connection_id, - &packet_info.source_connection_id, &retry_token_present, &retry_token, - &detailed_error); + &packet_info.source_connection_id, &retry_token, &detailed_error); ASSERT_THAT(error, IsQuicNoError()) << detailed_error; tls_chlo_extractor_.IngestPacket(packet_info.version, packet_info.packet); } @@ -64,6 +123,7 @@ class TlsChloExtractorTest : public QuicTestWithParam { } ParsedQuicVersion version_; + QuicServerId server_id_; TlsChloExtractor tls_chlo_extractor_; QuicConfig config_; std::vector> packets_; @@ -81,6 +141,42 @@ TEST_P(TlsChloExtractorTest, Simple) { ValidateChloDetails(); EXPECT_EQ(tls_chlo_extractor_.state(), TlsChloExtractor::State::kParsedFullSinglePacketChlo); + EXPECT_FALSE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ResumptionOnly) { + auto crypto_client_config = std::make_unique( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique()); + PerformFullHandshake(crypto_client_config.get()); + + SSL_CTX_set_early_data_enabled(crypto_client_config->ssl_ctx(), 0); + Initialize(std::move(crypto_client_config)); + EXPECT_GE(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullSinglePacketChlo); + EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_FALSE(tls_chlo_extractor_.early_data_attempted()); +} + +TEST_P(TlsChloExtractorTest, TlsExtentionInfo_ZeroRtt) { + auto crypto_client_config = std::make_unique( + crypto_test_utils::ProofVerifierForTesting(), + std::make_unique()); + PerformFullHandshake(crypto_client_config.get()); + + IncreaseSizeOfChlo(); + Initialize(std::move(crypto_client_config)); + EXPECT_GE(packets_.size(), 1u); + IngestPackets(); + ValidateChloDetails(); + EXPECT_EQ(tls_chlo_extractor_.state(), + TlsChloExtractor::State::kParsedFullMultiPacketChlo); + EXPECT_TRUE(tls_chlo_extractor_.resumption_attempted()); + EXPECT_TRUE(tls_chlo_extractor_.early_data_attempted()); } TEST_P(TlsChloExtractorTest, MultiPacket) { @@ -127,15 +223,14 @@ TEST_P(TlsChloExtractorTest, MoveAssignmentBetweenPackets) { QuicSocketAddress(TestPeerIPAddress(), kTestPort), QuicSocketAddress(TestPeerIPAddress(), kTestPort), *packets_[0]); std::string detailed_error; - bool retry_token_present; - absl::string_view retry_token; + absl::optional retry_token; const QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( *packets_[0], /*expected_destination_connection_id_length=*/0, &packet_info.form, &packet_info.long_packet_type, &packet_info.version_flag, &packet_info.use_length_prefix, &packet_info.version_label, &packet_info.version, &packet_info.destination_connection_id, &packet_info.source_connection_id, - &retry_token_present, &retry_token, &detailed_error); + &retry_token, &detailed_error); ASSERT_THAT(error, IsQuicNoError()) << detailed_error; other_extractor.IngestPacket(packet_info.version, packet_info.packet); // Remove the first packet from the list. diff --git a/gquiche/quic/core/tls_client_handshaker.cc b/gquiche/quic/core/tls_client_handshaker.cc index d8284f3b..c547ebb5 100644 --- a/gquiche/quic/core/tls_client_handshaker.cc +++ b/gquiche/quic/core/tls_client_handshaker.cc @@ -17,7 +17,7 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_hostname_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { @@ -41,14 +41,27 @@ TlsClientHandshaker::TlsClientHandshaker( crypto_negotiated_params_(new QuicCryptoNegotiatedParameters), has_application_state_(has_application_state), crypto_config_(crypto_config), - tls_connection_(crypto_config->ssl_ctx(), this) { - if (GetQuicReloadableFlag(quic_enable_token_based_address_validation)) { + tls_connection_(crypto_config->ssl_ctx(), this, session->GetSSLConfig()) { + if (!GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { std::string token = crypto_config->LookupOrCreate(server_id)->source_address_token(); if (!token.empty()) { session->SetSourceAddressTokenToSend(token); } } + if (crypto_config->tls_signature_algorithms().has_value()) { + SSL_set1_sigalgs_list(ssl(), + crypto_config->tls_signature_algorithms()->c_str()); + } + if (crypto_config->proof_source() != nullptr) { + const ClientProofSource::CertAndKey* cert_and_key = + crypto_config->proof_source()->GetCertAndKey(server_id.host()); + if (cert_and_key != nullptr) { + QUIC_DVLOG(1) << "Setting client cert and key for " << server_id.host(); + tls_connection_.SetCertChain(cert_and_key->chain->ToCryptoBuffers().value, + cert_and_key->private_key.private_key()); + } + } } TlsClientHandshaker::~TlsClientHandshaker() {} @@ -70,6 +83,17 @@ bool TlsClientHandshaker::CryptoConnect() { } SSL_set_quic_use_legacy_codepoint(ssl(), use_legacy_extension); + // TODO(b/193650832) Add SetFromConfig to QUIC handshakers and remove reliance + // on session pointer. + const bool permutes_tls_extensions = session()->permutes_tls_extensions(); + if (!permutes_tls_extensions) { + QUIC_DLOG(INFO) << "Disabling TLS extension permutation"; + } +#if BORINGSSL_API_VERSION >= 16 + // Ask BoringSSL to randomize the order of TLS extensions. + SSL_set_permute_extensions(ssl(), permutes_tls_extensions); +#endif // BORINGSSL_API_VERSION + // Set the SNI to send, if any. SSL_set_connect_state(ssl()); if (QUIC_DLOG_INFO_IS_ON() && @@ -98,10 +122,15 @@ bool TlsClientHandshaker::CryptoConnect() { // Set a session to resume, if there is one. if (session_cache_) { - cached_state_ = session_cache_->Lookup(server_id_, SSL_get_SSL_CTX(ssl())); + cached_state_ = session_cache_->Lookup( + server_id_, session()->GetClock()->WallNow(), SSL_get_SSL_CTX(ssl())); } if (cached_state_) { SSL_set_session(ssl(), cached_state_->tls_session.get()); + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache) && + !cached_state_->token.empty()) { + session()->SetSourceAddressTokenToSend(cached_state_->token); + } } // Start the handshake. @@ -179,18 +208,16 @@ bool TlsClientHandshaker::SetAlpn() { } // Enable ALPS only for versions that use HTTP/3 frames. - if (enable_alps_) { - for (const std::string& alpn_string : alpns) { - ParsedQuicVersion version = ParseQuicVersionString(alpn_string); - if (!version.IsKnown() || !version.UsesHttp3()) { - continue; - } - if (SSL_add_application_settings( - ssl(), reinterpret_cast(alpn_string.data()), - alpn_string.size(), nullptr, /* settings_len = */ 0) != 1) { - QUIC_BUG(quic_bug_10576_7) << "Failed to enable ALPS."; - return false; - } + for (const std::string& alpn_string : alpns) { + ParsedQuicVersion version = ParseQuicVersionString(alpn_string); + if (!version.IsKnown() || !version.UsesHttp3()) { + continue; + } + if (SSL_add_application_settings( + ssl(), reinterpret_cast(alpn_string.data()), + alpn_string.size(), nullptr, /* settings_len = */ 0) != 1) { + QUIC_BUG(quic_bug_10576_7) << "Failed to enable ALPS."; + return false; } } @@ -201,8 +228,14 @@ bool TlsClientHandshaker::SetAlpn() { bool TlsClientHandshaker::SetTransportParameters() { TransportParameters params; params.perspective = Perspective::IS_CLIENT; - params.version = + params.legacy_version_information = + TransportParameters::LegacyVersionInformation(); + params.legacy_version_information.value().version = CreateQuicVersionLabel(session()->supported_versions().front()); + params.version_information = TransportParameters::VersionInformation(); + const QuicVersionLabel version = CreateQuicVersionLabel(session()->version()); + params.version_information.value().chosen_version = version; + params.version_information.value().other_versions.push_back(version); if (!handshaker_delegate()->FillTransportParameters(¶ms)) { return false; @@ -246,27 +279,41 @@ bool TlsClientHandshaker::ProcessTransportParameters( session()->connection()->OnTransportParametersReceived( *received_transport_params_); - // When interoperating with non-Google implementations that do not send - // the version extension, set it to what we expect. - if (received_transport_params_->version == 0) { - received_transport_params_->version = - CreateQuicVersionLabel(session()->connection()->version()); + if (received_transport_params_->legacy_version_information.has_value()) { + if (received_transport_params_->legacy_version_information.value() + .version != + CreateQuicVersionLabel(session()->connection()->version())) { + *error_details = "Version mismatch detected"; + return false; + } + if (CryptoUtils::ValidateServerHelloVersions( + received_transport_params_->legacy_version_information.value() + .supported_versions, + session()->connection()->server_supported_versions(), + error_details) != QUIC_NO_ERROR) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } } - if (received_transport_params_->supported_versions.empty()) { - received_transport_params_->supported_versions.push_back( - received_transport_params_->version); + if (received_transport_params_->version_information.has_value()) { + if (!CryptoUtils::ValidateChosenVersion( + received_transport_params_->version_information.value() + .chosen_version, + session()->version(), error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + if (!CryptoUtils::CryptoUtils::ValidateServerVersions( + received_transport_params_->version_information.value() + .other_versions, + session()->version(), + session()->client_original_supported_versions(), error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } } - if (received_transport_params_->version != - CreateQuicVersionLabel(session()->connection()->version())) { - *error_details = "Version mismatch detected"; - return false; - } - if (CryptoUtils::ValidateServerHelloVersions( - received_transport_params_->supported_versions, - session()->connection()->server_supported_versions(), - error_details) != QUIC_NO_ERROR || - handshaker_delegate()->ProcessTransportParameters( + if (handshaker_delegate()->ProcessTransportParameters( *received_transport_params_, /* is_resumption = */ false, error_details) != QUIC_NO_ERROR) { QUICHE_DCHECK(!error_details->empty()); @@ -315,6 +362,13 @@ std::string TlsClientHandshaker::chlo_hash() const { return ""; } +bool TlsClientHandshaker::ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + return ExportKeyingMaterialForLabel(label, context, result_len, result); +} + bool TlsClientHandshaker::encryption_established() const { return encryption_established_; } @@ -386,9 +440,15 @@ void TlsClientHandshaker::OnNewTokenReceived(absl::string_view token) { if (token.empty()) { return; } - QuicCryptoClientConfig::CachedState* cached = - crypto_config_->LookupOrCreate(server_id_); - cached->set_source_address_token(token); + if (GetQuicReloadableFlag(quic_tls_use_token_in_session_cache)) { + if (session_cache_ != nullptr) { + session_cache_->OnNewTokenReceived(server_id_, token); + } + } else { + QuicCryptoClientConfig::CachedState* cached = + crypto_config_->LookupOrCreate(server_id_); + cached->set_source_address_token(token); + } } void TlsClientHandshaker::SetWriteSecret( @@ -446,27 +506,10 @@ void TlsClientHandshaker::OnProofVerifyDetailsAvailable( } void TlsClientHandshaker::FinishHandshake() { - // Fill crypto_negotiated_params_: - const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); - if (cipher) { - crypto_negotiated_params_->cipher_suite = - SSL_CIPHER_get_protocol_id(cipher); - } - crypto_negotiated_params_->key_exchange_group = SSL_get_curve_id(ssl()); - crypto_negotiated_params_->peer_signature_algorithm = - SSL_get_peer_signature_algorithm(ssl()); - if (SSL_in_early_data(ssl())) { - // SSL_do_handshake returns after sending the ClientHello if the session is - // 0-RTT-capable, which means that FinishHandshake will get called twice - - // the first time after sending the ClientHello, and the second time after - // the handshake is complete. If we're in the first time FinishHandshake is - // called, we can't do any end-of-handshake processing. - - // If we're attempting a 0-RTT handshake, then we need to let the transport - // and application know what state to apply to early data. - PrepareZeroRttConfig(cached_state_.get()); - return; - } + FillNegotiatedParams(); + + QUICHE_CHECK(!SSL_in_early_data(ssl())); + QUIC_LOG(INFO) << "Client: handshake finished"; std::string error_details; @@ -505,20 +548,18 @@ void TlsClientHandshaker::FinishHandshake() { << "'"; // Parse ALPS extension. - if (enable_alps_) { - const uint8_t* alps_data; - size_t alps_length; - SSL_get0_peer_application_settings(ssl(), &alps_data, &alps_length); - if (alps_length > 0) { - auto error = session()->OnAlpsData(alps_data, alps_length); - if (error) { - // Calling CloseConnection() is safe even in case OnAlpsData() has - // already closed the connection. - CloseConnection( - QUIC_HANDSHAKE_FAILED, - absl::StrCat("Error processing ALPS data: ", error.value())); - return; - } + const uint8_t* alps_data; + size_t alps_length; + SSL_get0_peer_application_settings(ssl(), &alps_data, &alps_length); + if (alps_length > 0) { + auto error = session()->OnAlpsData(alps_data, alps_length); + if (error) { + // Calling CloseConnection() is safe even in case OnAlpsData() has + // already closed the connection. + CloseConnection( + QUIC_HANDSHAKE_FAILED, + absl::StrCat("Error processing ALPS data: ", error.value())); + return; } } @@ -526,6 +567,29 @@ void TlsClientHandshaker::FinishHandshake() { handshaker_delegate()->OnTlsHandshakeComplete(); } +void TlsClientHandshaker::OnEnterEarlyData() { + QUICHE_DCHECK(SSL_in_early_data(ssl())); + + // TODO(wub): It might be unnecessary to FillNegotiatedParams() at this time, + // because we fill it again when handshake completes. + FillNegotiatedParams(); + + // If we're attempting a 0-RTT handshake, then we need to let the transport + // and application know what state to apply to early data. + PrepareZeroRttConfig(cached_state_.get()); +} + +void TlsClientHandshaker::FillNegotiatedParams() { + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + if (cipher) { + crypto_negotiated_params_->cipher_suite = + SSL_CIPHER_get_protocol_id(cipher); + } + crypto_negotiated_params_->key_exchange_group = SSL_get_curve_id(ssl()); + crypto_negotiated_params_->peer_signature_algorithm = + SSL_get_peer_signature_algorithm(ssl()); +} + void TlsClientHandshaker::ProcessPostHandshakeMessage() { int rv = SSL_process_quic_post_handshake(ssl()); if (rv != 1) { diff --git a/gquiche/quic/core/tls_client_handshaker.h b/gquiche/quic/core/tls_client_handshaker.h index 474e539c..46a92223 100644 --- a/gquiche/quic/core/tls_client_handshaker.h +++ b/gquiche/quic/core/tls_client_handshaker.h @@ -50,6 +50,8 @@ class QUIC_EXPORT_PRIVATE TlsClientHandshaker bool ReceivedInchoateReject() const override; int num_scup_messages_received() const override; std::string chlo_hash() const override; + bool ExportKeyingMaterial(absl::string_view label, absl::string_view context, + size_t result_len, std::string* result) override; // From QuicCryptoClientStream::HandshakerInterface and TlsHandshaker bool encryption_established() const override; @@ -81,8 +83,9 @@ class QUIC_EXPORT_PRIVATE TlsClientHandshaker void AllowEmptyAlpnForTests() { allow_empty_alpn_for_tests_ = true; } void AllowInvalidSNIForTests() { allow_invalid_sni_for_tests_ = true; } - SSL* GetSslForTests() { return tls_connection_.ssl(); } - const SSL* GetSslForTests() const { return tls_connection_.ssl(); } + + // Make the SSL object from BoringSSL publicly accessible. + using TlsHandshaker::ssl; protected: const TlsConnection* tls_connection() const override { @@ -90,6 +93,8 @@ class QUIC_EXPORT_PRIVATE TlsClientHandshaker } void FinishHandshake() override; + void OnEnterEarlyData() override; + void FillNegotiatedParams(); void ProcessPostHandshakeMessage() override; bool ShouldCloseConnectionOnUnexpectedError(int ssl_error) override; QuicAsyncStatus VerifyCertChain( @@ -166,9 +171,6 @@ class QUIC_EXPORT_PRIVATE TlsClientHandshaker std::unique_ptr received_transport_params_ = nullptr; std::unique_ptr received_application_state_ = nullptr; - - // Latched value of reloadable flag quic_enable_alps_client. - const bool enable_alps_ = GetQuicReloadableFlag(quic_enable_alps_client); }; } // namespace quic diff --git a/gquiche/quic/core/tls_client_handshaker_test.cc b/gquiche/quic/core/tls_client_handshaker_test.cc index 7ed63f07..03794049 100644 --- a/gquiche/quic/core/tls_client_handshaker_test.cc +++ b/gquiche/quic/core/tls_client_handshaker_test.cc @@ -279,11 +279,7 @@ TEST_P(TlsClientHandshakerTest, ConnectedAfterHandshake) { TEST_P(TlsClientHandshakerTest, ConnectionClosedOnTlsError) { // Have client send ClientHello. stream()->CryptoConnect(); - if (GetQuicReloadableFlag(quic_send_tls_crypto_error_code)) { - EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _)); - } else { - EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); - } + EXPECT_CALL(*connection_, CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _)); // Send a zero-length ServerHello from server to client. char bogus_handshake_message[] = { @@ -423,6 +419,55 @@ TEST_P(TlsClientHandshakerTest, ZeroRttResumption) { EXPECT_EQ(stream()->EarlyDataReason(), ssl_early_data_accepted); } +// Regression test for b/186438140. +TEST_P(TlsClientHandshakerTest, ZeroRttResumptionWithAyncProofVerifier) { + // Finish establishing the first connection, so the second connection can + // resume. + CompleteCryptoHandshake(); + + EXPECT_EQ(PROTOCOL_TLS1_3, stream()->handshake_protocol()); + EXPECT_TRUE(stream()->encryption_established()); + EXPECT_TRUE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(stream()->IsResumption()); + + // Create a second connection. + CreateConnection(); + InitializeFakeServer(); + EXPECT_CALL(*session_, OnConfigNegotiated()); + EXPECT_CALL(*connection_, SendCryptoData(_, _, _)) + .Times(testing::AnyNumber()); + // Enable TestProofVerifier to capture the call to VerifyCertChain and run it + // asynchronously. + TestProofVerifier* proof_verifier = + static_cast(crypto_config_->proof_verifier()); + proof_verifier->Activate(); + // Start the second handshake. + stream()->CryptoConnect(); + + ASSERT_EQ(proof_verifier->NumPendingCallbacks(), 1u); + + // Advance the handshake with the server. Since cert verification has not + // finished yet, client cannot derive HANDSHAKE and 1-RTT keys. + crypto_test_utils::AdvanceHandshake(connection_, stream(), 0, + server_connection_, server_stream(), 0); + + EXPECT_FALSE(stream()->one_rtt_keys_available()); + EXPECT_FALSE(server_stream()->one_rtt_keys_available()); + + // Finish cert verification after receiving packets from server. + proof_verifier->InvokePendingCallback(0); + + QuicFramer* framer = QuicConnectionPeer::GetFramer(connection_); + // Verify client has derived HANDSHAKE key. + EXPECT_NE(nullptr, + QuicFramerPeer::GetEncrypter(framer, ENCRYPTION_HANDSHAKE)); + + // Ideally, we should also verify that the process_undecryptable_packets_alarm + // is set and processing the undecryptable packets can advance the handshake + // to completion. Unfortunately, the test facilities used in this test does + // not support queuing and processing undecryptable packets. +} + TEST_P(TlsClientHandshakerTest, ZeroRttRejection) { // Finish establishing the first connection: CompleteCryptoHandshake(); @@ -565,23 +610,14 @@ TEST_P(TlsClientHandshakerTest, ServerRequiresCustomALPN) { .WillOnce([kTestAlpn](const std::vector& alpns) { return std::find(alpns.cbegin(), alpns.cend(), kTestAlpn); }); - if (GetQuicReloadableFlag(quic_send_tls_crypto_error_code)) { - EXPECT_CALL( - *server_connection_, - CloseConnection( - QUIC_HANDSHAKE_FAILED, - static_cast(CRYPTO_ERROR_FIRST + 120), - "TLS handshake failure (ENCRYPTION_INITIAL) 120: " - "no application protocol", - _)); - } else { - EXPECT_CALL( - *server_connection_, - CloseConnection(QUIC_HANDSHAKE_FAILED, - "TLS handshake failure (ENCRYPTION_INITIAL) 120: " - "no application protocol", - _)); - } + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, + static_cast( + CRYPTO_ERROR_FIRST + 120), + "TLS handshake failure (ENCRYPTION_INITIAL) 120: " + "no application protocol", + _)); stream()->CryptoConnect(); crypto_test_utils::AdvanceHandshake(connection_, stream(), 0, diff --git a/gquiche/quic/core/tls_handshaker.cc b/gquiche/quic/core/tls_handshaker.cc index 416251c9..f7c126e4 100644 --- a/gquiche/quic/core/tls_handshaker.cc +++ b/gquiche/quic/core/tls_handshaker.cc @@ -12,9 +12,12 @@ #include "gquiche/quic/core/quic_crypto_stream.h" #include "gquiche/quic/core/tls_client_handshaker.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" +#include "gquiche/quic/platform/api/quic_stack_trace.h" namespace quic { +#define ENDPOINT (SSL_is_server(ssl()) ? "TlsServer: " : "TlsClient: ") + TlsHandshaker::ProofVerifierCallbackImpl::ProofVerifierCallbackImpl( TlsHandshaker* parent) : parent_(parent) {} @@ -93,15 +96,43 @@ void TlsHandshaker::AdvanceHandshake() { return; } - QUICHE_BUG_IF(quic_tls_server_async_done_no_flusher, - SSL_is_server(ssl()) && add_packet_flusher_on_async_op_done_ && - !handshaker_delegate_->PacketFlusherAttached()) - << "is_server:" << SSL_is_server(ssl()) - << ", add_packet_flusher_on_async_op_done_:" - << add_packet_flusher_on_async_op_done_; + QUICHE_BUG_IF( + quic_tls_server_async_done_no_flusher, + SSL_is_server(ssl()) && !handshaker_delegate_->PacketFlusherAttached()) + << "is_server:" << SSL_is_server(ssl()); - QUIC_VLOG(1) << "TlsHandshaker: continuing handshake"; + QUIC_VLOG(1) << ENDPOINT << "Continuing handshake"; int rv = SSL_do_handshake(ssl()); + + // If SSL_do_handshake return success(1) and we are in early data, it is + // possible that we have provided ServerHello to BoringSSL but it hasn't been + // processed. Retry SSL_do_handshake once will advance the handshake more in + // that case. If there are no unprocessed ServerHello, the retry will return a + // non-positive number. + if (rv == 1 && SSL_in_early_data(ssl())) { + OnEnterEarlyData(); + rv = SSL_do_handshake(ssl()); + QUIC_VLOG(1) << ENDPOINT + << "SSL_do_handshake returned when entering early data. After " + << "retry, rv=" << rv + << ", SSL_in_early_data=" << SSL_in_early_data(ssl()); + // The retry should either + // - Return <= 0 if the handshake is still pending, likely still in early + // data. + // - Return 1 if the handshake has _actually_ finished. i.e. + // SSL_in_early_data should be false. + // + // In either case, it should not both return 1 and stay in early data. + if (rv == 1 && SSL_in_early_data(ssl()) && !is_connection_closed_) { + QUIC_BUG(quic_handshaker_stay_in_early_data) + << "The original and the retry of SSL_do_handshake both returned " + "success and in early data"; + CloseConnection(QUIC_HANDSHAKE_FAILED, + "TLS handshake failed: Still in early data after retry"); + return; + } + } + if (rv == 1) { FinishHandshake(); return; @@ -176,6 +207,7 @@ enum ssl_verify_result_t TlsHandshaker::VerifyCert(uint8_t* out_alert) { std::string(reinterpret_cast(CRYPTO_BUFFER_data(cert)), CRYPTO_BUFFER_len(cert))); } + QUIC_DVLOG(1) << "VerifyCert: peer cert_chain length: " << certs.size(); ProofVerifierCallbackImpl* proof_verify_callback = new ProofVerifierCallbackImpl(this); @@ -207,7 +239,7 @@ enum ssl_verify_result_t TlsHandshaker::VerifyCert(uint8_t* out_alert) { void TlsHandshaker::SetWriteSecret(EncryptionLevel level, const SSL_CIPHER* cipher, const std::vector& write_secret) { - QUIC_DVLOG(1) << "SetWriteSecret level=" << level; + QUIC_DVLOG(1) << ENDPOINT << "SetWriteSecret level=" << level; std::unique_ptr encrypter = QuicEncrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); const EVP_MD* prf = Prf(cipher); @@ -230,7 +262,7 @@ void TlsHandshaker::SetWriteSecret(EncryptionLevel level, bool TlsHandshaker::SetReadSecret(EncryptionLevel level, const SSL_CIPHER* cipher, const std::vector& read_secret) { - QUIC_DVLOG(1) << "SetReadSecret level=" << level; + QUIC_DVLOG(1) << ENDPOINT << "SetReadSecret level=" << level; std::unique_ptr decrypter = QuicDecrypter::CreateFromCipherSuite(SSL_CIPHER_get_id(cipher)); const EVP_MD* prf = Prf(cipher); @@ -297,6 +329,23 @@ std::unique_ptr TlsHandshaker::CreateCurrentOneRttEncrypter() { return encrypter; } +bool TlsHandshaker::ExportKeyingMaterialForLabel(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + // TODO(haoyuewang) Adding support of keying material export when 0-RTT is + // accepted. + if (SSL_in_init(ssl())) { + return false; + } + result->resize(result_len); + return SSL_export_keying_material( + ssl(), reinterpret_cast(&*result->begin()), result_len, + label.data(), label.size(), + reinterpret_cast(context.data()), context.size(), + !context.empty()) == 1; +} + void TlsHandshaker::WriteMessage(EncryptionLevel level, absl::string_view data) { stream_->WriteCryptoData(level, data); @@ -309,15 +358,10 @@ void TlsHandshaker::SendAlert(EncryptionLevel level, uint8_t desc) { "TLS handshake failure (", EncryptionLevelToString(level), ") ", static_cast(desc), ": ", SSL_alert_desc_string_long(desc)); QUIC_DLOG(ERROR) << error_details; - if (GetQuicReloadableFlag(quic_send_tls_crypto_error_code)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_send_tls_crypto_error_code); - CloseConnection( - TlsAlertToQuicErrorCode(desc), - static_cast(CRYPTO_ERROR_FIRST + desc), - error_details); - } else { - CloseConnection(QUIC_HANDSHAKE_FAILED, error_details); - } + CloseConnection( + TlsAlertToQuicErrorCode(desc), + static_cast(CRYPTO_ERROR_FIRST + desc), + error_details); } } // namespace quic diff --git a/gquiche/quic/core/tls_handshaker.h b/gquiche/quic/core/tls_handshaker.h index e19debda..9b1e5d79 100644 --- a/gquiche/quic/core/tls_handshaker.h +++ b/gquiche/quic/core/tls_handshaker.h @@ -53,11 +53,14 @@ class QUIC_EXPORT_PRIVATE TlsHandshaker : public TlsConnection::Delegate, std::unique_ptr AdvanceKeysAndCreateCurrentOneRttDecrypter(); std::unique_ptr CreateCurrentOneRttEncrypter(); virtual HandshakeState GetHandshakeState() const = 0; + bool ExportKeyingMaterialForLabel(absl::string_view label, + absl::string_view context, + size_t result_len, std::string* result); protected: // Called when a new message is received on the crypto stream and is available // for the TLS stack to read. - void AdvanceHandshake(); + virtual void AdvanceHandshake(); void CloseConnection(QuicErrorCode error, const std::string& reason_phrase); // Closes the connection, specifying the wire error code |ietf_error| @@ -71,11 +74,18 @@ class QUIC_EXPORT_PRIVATE TlsHandshaker : public TlsConnection::Delegate, bool is_connection_closed() const { return is_connection_closed_; } // Called when |SSL_do_handshake| returns 1, indicating that the handshake has - // finished. Note that due to 0-RTT, the handshake may "finish" twice; - // |SSL_in_early_data| can be used to determine whether the handshake is truly - // done. + // finished. Note that a handshake only finishes once, entering early data + // does not count. virtual void FinishHandshake() = 0; + // Called when |SSL_do_handshake| returns 1 and the connection is in early + // data. In that case, |AdvanceHandshake| will call |OnEnterEarlyData| and + // retry |SSL_do_handshake| once. + virtual void OnEnterEarlyData() { + // By default, do nothing but check the preconditions. + QUICHE_DCHECK(SSL_in_early_data(ssl())); + } + // Called when a handshake message is received after the handshake is // complete. virtual void ProcessPostHandshakeMessage() = 0; @@ -90,12 +100,11 @@ class QUIC_EXPORT_PRIVATE TlsHandshaker : public TlsConnection::Delegate, } int expected_ssl_error() const { return expected_ssl_error_; } - // Called to verify a cert chain. This is a simple wrapper around - // ProofVerifier or ServerProofVerifier, which optionally gathers additional - // arguments to pass into their VerifyCertChain method. This class retains a - // non-owning pointer to |callback|; the callback must live until this - // function returns QUIC_SUCCESS or QUIC_FAILURE, or until the callback is - // run. + // Called to verify a cert chain. This can be implemented as a simple wrapper + // around ProofVerifier, which optionally gathers additional arguments to pass + // into their VerifyCertChain method. This class retains a non-owning pointer + // to |callback|; the callback must live until this function returns + // QUIC_SUCCESS or QUIC_FAILURE, or until the callback is run. // // If certificate verification fails, |*out_alert| may be set to a TLS alert // that will be sent when closing the connection; it defaults to @@ -156,8 +165,10 @@ class QUIC_EXPORT_PRIVATE TlsHandshaker : public TlsConnection::Delegate, // error code corresponding to the TLS alert description |desc|. void SendAlert(EncryptionLevel level, uint8_t desc) override; - const bool add_packet_flusher_on_async_op_done_ = - GetQuicReloadableFlag(quic_add_packet_flusher_on_async_op_done); + // Informational callback from BoringSSL. Subclasses can override it to do + // logging, tracing, etc. + // See |SSL_CTX_set_info_callback| for the meaning of |type| and |value|. + void InfoCallback(int /*type*/, int /*value*/) override {} private: // ProofVerifierCallbackImpl handles the result of an asynchronous certificate diff --git a/gquiche/quic/core/tls_server_handshaker.cc b/gquiche/quic/core/tls_server_handshaker.cc index be98da26..27f9100a 100644 --- a/gquiche/quic/core/tls_server_handshaker.cc +++ b/gquiche/quic/core/tls_server_handshaker.cc @@ -8,6 +8,7 @@ #include #include "absl/base/macros.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "openssl/pool.h" #include "openssl/ssl.h" @@ -22,7 +23,6 @@ #include "gquiche/quic/platform/api/quic_hostname_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_server_stats.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #define RECORD_LATENCY_IN_US(stat_name, latency, comment) \ do { \ @@ -34,17 +34,24 @@ namespace quic { +namespace { + +// Default port for HTTP/3. +uint16_t kDefaultPort = 443; + +} // namespace + TlsServerHandshaker::DefaultProofSourceHandle::DefaultProofSourceHandle( TlsServerHandshaker* handshaker, ProofSource* proof_source) : handshaker_(handshaker), proof_source_(proof_source) {} TlsServerHandshaker::DefaultProofSourceHandle::~DefaultProofSourceHandle() { - CancelPendingOperation(); + CloseHandle(); } -void TlsServerHandshaker::DefaultProofSourceHandle::CancelPendingOperation() { - QUIC_DVLOG(1) << "CancelPendingOperation. is_signature_pending=" +void TlsServerHandshaker::DefaultProofSourceHandle::CloseHandle() { + QUIC_DVLOG(1) << "CloseHandle. is_signature_pending=" << (signature_callback_ != nullptr); if (signature_callback_) { signature_callback_->Cancel(); @@ -56,22 +63,30 @@ QuicAsyncStatus TlsServerHandshaker::DefaultProofSourceHandle::SelectCertificate( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, + absl::string_view /*ssl_capabilities*/, const std::string& hostname, absl::string_view /*client_hello*/, const std::string& /*alpn*/, + absl::optional /*alps*/, const std::vector& /*quic_transport_params*/, - const absl::optional>& /*early_data_context*/) { + const absl::optional>& /*early_data_context*/, + const QuicSSLConfig& /*ssl_config*/) { if (!handshaker_ || !proof_source_) { QUIC_BUG(quic_bug_10341_1) << "SelectCertificate called on a detached handle"; return QUIC_FAILURE; } + bool cert_matched_sni; QuicReferenceCountedPointer chain = - proof_source_->GetCertChain(server_address, client_address, hostname); + proof_source_->GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); handshaker_->OnSelectCertificateDone( - /*ok=*/true, /*is_sync=*/true, chain.get()); + /*ok=*/true, /*is_sync=*/true, chain.get(), + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), cert_matched_sni, + QuicDelayedSSLConfig()); if (!handshaker_->select_cert_status().has_value()) { QUIC_BUG(quic_bug_12423_1) << "select_cert_status() has no value after a synchronous select cert"; @@ -126,26 +141,39 @@ void TlsServerHandshaker::DecryptCallback::Run(std::vector plaintext) { // The callback was cancelled before we could run. return; } - handshaker_->decrypted_session_ticket_ = std::move(plaintext); + + TlsServerHandshaker* handshaker = handshaker_; + handshaker_ = nullptr; + + handshaker->decrypted_session_ticket_ = std::move(plaintext); + const bool is_async = + (handshaker->expected_ssl_error() == SSL_ERROR_PENDING_TICKET); + + absl::optional context_switcher; + + if (is_async) { + context_switcher.emplace(handshaker->connection_context()); + } + QUIC_TRACESTRING( + absl::StrCat("TLS ticket decryption done. len(decrypted_ticket):", + handshaker->decrypted_session_ticket_.size())); + // DecryptCallback::Run could be called synchronously. When that happens, we // are currently in the middle of a call to AdvanceHandshake. - // (AdvanceHandshake called SSL_do_handshake, which through some layers called - // SessionTicketOpen, which called TicketCrypter::Decrypt, which synchronously - // called this function.) In that case, the handshake will continue to be - // processed when this function returns. + // (AdvanceHandshake called SSL_do_handshake, which through some layers + // called SessionTicketOpen, which called TicketCrypter::Decrypt, which + // synchronously called this function.) In that case, the handshake will + // continue to be processed when this function returns. // - // When this callback is called asynchronously (i.e. the ticket decryption is - // pending), TlsServerHandshaker is not actively processing handshake + // When this callback is called asynchronously (i.e. the ticket decryption + // is pending), TlsServerHandshaker is not actively processing handshake // messages. We need to have it resume processing handshake messages by // calling AdvanceHandshake. - if (handshaker_->expected_ssl_error() == SSL_ERROR_PENDING_TICKET) { - handshaker_->AdvanceHandshakeFromCallback(); + if (is_async) { + handshaker->AdvanceHandshakeFromCallback(); } - // The TicketDecrypter took ownership of this callback when Decrypt was - // called. Once the callback returns, it will be deleted. Remove the - // (non-owning) pointer to the callback from the handshaker so the handshaker - // doesn't have an invalid pointer hanging around. - handshaker_->ticket_decryption_callback_ = nullptr; + + handshaker->ticket_decryption_callback_ = nullptr; } void TlsServerHandshaker::DecryptCallback::Cancel() { @@ -154,15 +182,18 @@ void TlsServerHandshaker::DecryptCallback::Cancel() { } TlsServerHandshaker::TlsServerHandshaker( - QuicSession* session, - const QuicCryptoServerConfig* crypto_config) + QuicSession* session, const QuicCryptoServerConfig* crypto_config) : TlsHandshaker(this, session), QuicCryptoServerStreamBase(session), proof_source_(crypto_config->proof_source()), pre_shared_key_(crypto_config->pre_shared_key()), crypto_negotiated_params_(new QuicCryptoNegotiatedParameters), - tls_connection_(crypto_config->ssl_ctx(), this), + tls_connection_(crypto_config->ssl_ctx(), this, session->GetSSLConfig()), crypto_config_(crypto_config) { + QUIC_DVLOG(1) << "TlsServerHandshaker: support_client_cert:" + << session->support_client_cert() + << ", client_cert_mode initial value: " << client_cert_mode(); + QUICHE_DCHECK_EQ(PROTOCOL_TLS1_3, session->connection()->version().handshake_protocol); @@ -176,18 +207,20 @@ TlsServerHandshaker::TlsServerHandshaker( } SSL_set_quic_use_legacy_codepoint(ssl(), use_legacy_extension); - if (GetQuicFlag(FLAGS_quic_disable_server_tls_resumption)) { - SSL_set_options(ssl(), SSL_OP_NO_TICKET); + if (session->connection()->context()->tracer) { + tls_connection_.EnableInfoCallback(); } -} -TlsServerHandshaker::~TlsServerHandshaker() { - CancelOutstandingCallbacks(); + if (no_select_cert_if_disconnected_) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_tls_no_select_cert_if_disconnected, 1, 2); + } } +TlsServerHandshaker::~TlsServerHandshaker() { CancelOutstandingCallbacks(); } + void TlsServerHandshaker::CancelOutstandingCallbacks() { if (proof_source_handle_) { - proof_source_handle_->CancelPendingOperation(); + proof_source_handle_->CloseHandle(); } if (ticket_decryption_callback_) { ticket_decryption_callback_->Cancel(); @@ -195,6 +228,40 @@ void TlsServerHandshaker::CancelOutstandingCallbacks() { } } +void TlsServerHandshaker::InfoCallback(int type, int value) { + QuicConnectionTracer* tracer = + session()->connection()->context()->tracer.get(); + + if (tracer == nullptr) { + return; + } + + if (type & SSL_CB_LOOP) { + tracer->PrintString( + absl::StrCat("SSL:ACCEPT_LOOP:", SSL_state_string_long(ssl()))); + } else if (type & SSL_CB_ALERT) { + const char* prefix = + (type & SSL_CB_READ) ? "SSL:READ_ALERT:" : "SSL:WRITE_ALERT:"; + tracer->PrintString(absl::StrCat(prefix, SSL_alert_type_string_long(value), + ":", SSL_alert_desc_string_long(value))); + } else if (type & SSL_CB_EXIT) { + const char* prefix = + (value == 1) ? "SSL:ACCEPT_EXIT_OK:" : "SSL:ACCEPT_EXIT_FAIL:"; + tracer->PrintString(absl::StrCat(prefix, SSL_state_string_long(ssl()))); + } else if (type & SSL_CB_HANDSHAKE_START) { + tracer->PrintString( + absl::StrCat("SSL:HANDSHAKE_START:", SSL_state_string_long(ssl()))); + } else if (type & SSL_CB_HANDSHAKE_DONE) { + tracer->PrintString( + absl::StrCat("SSL:HANDSHAKE_DONE:", SSL_state_string_long(ssl()))); + } else { + QUIC_DLOG(INFO) << "Unknown event type " << type << ": " + << SSL_state_string_long(ssl()); + tracer->PrintString( + absl::StrCat("SSL:unknown:", value, ":", SSL_state_string_long(ssl()))); + } +} + std::unique_ptr TlsServerHandshaker::MaybeCreateProofSourceHandle() { return std::make_unique(this, proof_source_); @@ -230,11 +297,14 @@ int TlsServerHandshaker::NumServerConfigUpdateMessagesSent() const { const CachedNetworkParameters* TlsServerHandshaker::PreviousCachedNetworkParams() const { - return nullptr; + return last_received_cached_network_params_.get(); } void TlsServerHandshaker::SetPreviousCachedNetworkParams( - CachedNetworkParameters /*cached_network_params*/) {} + CachedNetworkParameters cached_network_params) { + last_received_cached_network_params_ = + std::make_unique(cached_network_params); +} void TlsServerHandshaker::OnPacketDecrypted(EncryptionLevel level) { if (level == ENCRYPTION_HANDSHAKE && state_ < HANDSHAKE_PROCESSED) { @@ -252,34 +322,43 @@ void TlsServerHandshaker::OnNewTokenReceived(absl::string_view /*token*/) { QUICHE_DCHECK(false); } -std::string TlsServerHandshaker::GetAddressToken() const { +std::string TlsServerHandshaker::GetAddressToken( + const CachedNetworkParameters* cached_network_params) const { SourceAddressTokens empty_previous_tokens; const QuicConnection* connection = session()->connection(); return crypto_config_->NewSourceAddressToken( crypto_config_->source_address_token_boxer(), empty_previous_tokens, connection->effective_peer_address().host(), connection->random_generator(), connection->clock()->WallNow(), - /*cached_network_params=*/nullptr); + session()->add_cached_network_parameters_to_address_token() + ? cached_network_params + : nullptr); } bool TlsServerHandshaker::ValidateAddressToken(absl::string_view token) const { SourceAddressTokens tokens; HandshakeFailureReason reason = crypto_config_->ParseSourceAddressToken( - crypto_config_->source_address_token_boxer(), token, &tokens); + crypto_config_->source_address_token_boxer(), token, tokens); if (reason != HANDSHAKE_OK) { QUIC_DLOG(WARNING) << "Failed to parse source address token: " << CryptoUtils::HandshakeFailureReasonToString(reason); return false; } + auto cached_network_params = std::make_unique(); reason = crypto_config_->ValidateSourceAddressTokens( tokens, session()->connection()->effective_peer_address().host(), session()->connection()->clock()->WallNow(), - /*cached_network_params=*/nullptr); + session()->add_cached_network_parameters_to_address_token() + ? cached_network_params.get() + : nullptr); if (reason != HANDSHAKE_OK) { QUIC_DLOG(WARNING) << "Failed to validate source address token: " << CryptoUtils::HandshakeFailureReasonToString(reason); return false; } + if (session()->add_cached_network_parameters_to_address_token()) { + last_received_cached_network_params_ = std::move(cached_network_params); + } return true; } @@ -287,10 +366,19 @@ bool TlsServerHandshaker::ShouldSendExpectCTHeader() const { return false; } +bool TlsServerHandshaker::DidCertMatchSni() const { return cert_matched_sni_; } + const ProofSource::Details* TlsServerHandshaker::ProofSourceDetails() const { return proof_source_details_.get(); } +bool TlsServerHandshaker::ExportKeyingMaterial(absl::string_view label, + absl::string_view context, + size_t result_len, + std::string* result) { + return ExportKeyingMaterialForLabel(label, context, result_len, result); +} + void TlsServerHandshaker::OnConnectionClosed(QuicErrorCode error, ConnectionCloseSource source) { TlsHandshaker::OnConnectionClosed(error, source); @@ -348,18 +436,7 @@ TlsServerHandshaker::CreateCurrentOneRttEncrypter() { void TlsServerHandshaker::OverrideQuicConfigDefaults(QuicConfig* /*config*/) {} void TlsServerHandshaker::AdvanceHandshakeFromCallback() { - std::unique_ptr flusher; - if (add_packet_flusher_on_async_op_done_) { - if (session()->PacketFlusherAttached()) { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_add_packet_flusher_on_async_op_done, 1, - 2); - } else { - QUIC_RELOADABLE_FLAG_COUNT_N(quic_add_packet_flusher_on_async_op_done, 2, - 2); - } - flusher = std::make_unique( - session()->connection()); - } + QuicConnection::ScopedPacketFlusher flusher(session()->connection()); AdvanceHandshake(); if (!is_connection_closed()) { @@ -406,34 +483,46 @@ bool TlsServerHandshaker::ProcessTransportParameters( // Notify QuicConnectionDebugVisitor. session()->connection()->OnTransportParametersReceived(client_params); - // Chrome clients before 86.0.4233.0 did not send the - // key_update_not_yet_supported transport parameter, but they did send a - // Google-internal transport parameter with identifier 0x4751. We treat - // reception of 0x4751 as having received key_update_not_yet_supported to - // ensure we do not use key updates with those older clients. - // TODO(dschinazi) remove this workaround once all of our QUIC+TLS Finch - // experiments have a min_version greater than 86.0.4233.0. - if (client_params.custom_parameters.find( - static_cast(0x4751)) != - client_params.custom_parameters.end()) { - client_params.key_update_not_yet_supported = true; - } - - // When interoperating with non-Google implementations that do not send - // the version extension, set it to what we expect. - if (client_params.version == 0) { - client_params.version = - CreateQuicVersionLabel(session()->connection()->version()); - } - - if (CryptoUtils::ValidateClientHelloVersion( - client_params.version, session()->connection()->version(), - session()->supported_versions(), error_details) != QUIC_NO_ERROR || - handshaker_delegate()->ProcessTransportParameters( + if (GetQuicReloadableFlag(quic_ignore_key_update_not_yet_supported)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_ignore_key_update_not_yet_supported, 2, + 2); + } else { + // Chrome clients before 86.0.4233.0 did not send the + // key_update_not_yet_supported transport parameter, but they did send a + // Google-internal transport parameter with identifier 0x4751. We treat + // reception of 0x4751 as having received key_update_not_yet_supported to + // ensure we do not use key updates with those older clients. + // TODO(dschinazi) remove this workaround once all of our QUIC+TLS Finch + // experiments have a min_version greater than 86.0.4233.0. + if (client_params.custom_parameters.find( + static_cast(0x4751)) != + client_params.custom_parameters.end()) { + client_params.key_update_not_yet_supported = true; + } + } + + if (client_params.legacy_version_information.has_value() && + CryptoUtils::ValidateClientHelloVersion( + client_params.legacy_version_information.value().version, + session()->connection()->version(), session()->supported_versions(), + error_details) != QUIC_NO_ERROR) { + return false; + } + + if (client_params.version_information.has_value() && + !CryptoUtils::ValidateChosenVersion( + client_params.version_information.value().chosen_version, + session()->version(), error_details)) { + QUICHE_DCHECK(!error_details->empty()); + return false; + } + + if (handshaker_delegate()->ProcessTransportParameters( client_params, /* is_resumption = */ false, error_details) != - QUIC_NO_ERROR) { + QUIC_NO_ERROR) { return false; } + ProcessAdditionalTransportParameters(client_params); if (!session()->user_agent_id().has_value() && client_params.user_agent_id.has_value()) { @@ -450,10 +539,21 @@ TlsServerHandshaker::SetTransportParameters() { TransportParameters server_params; server_params.perspective = Perspective::IS_SERVER; - server_params.supported_versions = + server_params.legacy_version_information = + TransportParameters::LegacyVersionInformation(); + server_params.legacy_version_information.value().supported_versions = CreateQuicVersionLabelVector(session()->supported_versions()); - server_params.version = + server_params.legacy_version_information.value().version = CreateQuicVersionLabel(session()->connection()->version()); + if (GetQuicReloadableFlag(quic_version_information)) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_version_information, 1, 2); + server_params.version_information = + TransportParameters::VersionInformation(); + server_params.version_information.value().chosen_version = + CreateQuicVersionLabel(session()->version()); + server_params.version_information.value().other_versions = + CreateQuicVersionLabelVector(session()->supported_versions()); + } if (!handshaker_delegate()->FillTransportParameters(&server_params)) { return result; @@ -511,22 +611,14 @@ void TlsServerHandshaker::SetWriteSecret( TlsHandshaker::SetWriteSecret(level, cipher, write_secret); } -std::string TlsServerHandshaker::GetAcceptChValueForOrigin( - const std::string& /*origin*/) const { +std::string TlsServerHandshaker::GetAcceptChValueForHostname( + const std::string& /*hostname*/) const { return {}; } void TlsServerHandshaker::FinishHandshake() { - if (SSL_in_early_data(ssl())) { - // If the server accepts early data, SSL_do_handshake returns success twice: - // once after processing the ClientHello and sending the server's first - // flight, and then again after the handshake is complete. This results in - // FinishHandshake getting called twice. On the first call to - // FinishHandshake, we don't have any confirmation that the client is live, - // so all end of handshake processing is deferred until the handshake is - // actually complete. - return; - } + QUICHE_DCHECK(!SSL_in_early_data(ssl())); + if (!valid_alpn_received_) { QUIC_DLOG(ERROR) << "Server: handshake finished without receiving a known ALPN"; @@ -559,9 +651,18 @@ QuicAsyncStatus TlsServerHandshaker::VerifyCertChain( std::unique_ptr* /*details*/, uint8_t* /*out_alert*/, std::unique_ptr /*callback*/) { - QUIC_BUG(quic_bug_10341_5) - << "Client certificates are not yet supported on the server"; - return QUIC_FAILURE; + if (!session()->support_client_cert()) { + QUIC_BUG(quic_bug_10341_5) + << "Client certificates are not yet supported on the server"; + return QUIC_FAILURE; + } + + QUIC_RESTART_FLAG_COUNT_N(quic_tls_server_support_client_cert, 2, 2); + QUIC_DVLOG(1) << "VerifyCertChain returning success"; + + // No real verification here. A subclass can override this function to verify + // the client cert if needed. + return QUIC_SUCCESS; } void TlsServerHandshaker::OnProofVerifyDetailsAvailable( @@ -577,7 +678,7 @@ ssl_private_key_result_t TlsServerHandshaker::PrivateKeySign( QuicAsyncStatus status = proof_source_handle_->ComputeSignature( session()->connection()->self_address(), - session()->connection()->peer_address(), cert_selection_hostname(), + session()->connection()->peer_address(), crypto_negotiated_params_->sni, sig_alg, in, max_out); if (status == QUIC_PENDING) { set_expected_ssl_error(SSL_ERROR_WANT_PRIVATE_KEY_OPERATION); @@ -632,6 +733,15 @@ void TlsServerHandshaker::OnComputeSignatureDone( QUIC_DVLOG(1) << "OnComputeSignatureDone. ok:" << ok << ", is_sync:" << is_sync << ", len(signature):" << signature.size(); + absl::optional context_switcher; + + if (!is_sync) { + context_switcher.emplace(connection_context()); + } + + QUIC_TRACESTRING(absl::StrCat("TLS compute signature done. ok:", ok, + ", len(signature):", signature.size())); + if (ok) { cert_verify_sig_ = std::move(signature); proof_source_details_ = std::move(details); @@ -660,7 +770,8 @@ int TlsServerHandshaker::SessionTicketSeal(uint8_t* out, size_t max_out_len, absl::string_view in) { QUICHE_DCHECK(proof_source_->GetTicketCrypter()); - std::vector ticket = proof_source_->GetTicketCrypter()->Encrypt(in); + std::vector ticket = + proof_source_->GetTicketCrypter()->Encrypt(in, ticket_encryption_key_); if (max_out_len < ticket.size()) { QUIC_BUG(quic_bug_12423_2) << "TicketCrypter returned " << ticket.size() @@ -681,20 +792,29 @@ ssl_ticket_aead_result_t TlsServerHandshaker::SessionTicketOpen( absl::string_view in) { QUICHE_DCHECK(proof_source_->GetTicketCrypter()); + if (ignore_ticket_open_) { + // SetIgnoreTicketOpen has been called. Typically this means the caller is + // using handshake hints and expect the hints to contain ticket decryption + // results. + QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_ignored_1); + return ssl_ticket_aead_ignore_ticket; + } + if (!ticket_decryption_callback_) { - ticket_received_ = true; ticket_decryption_callback_ = new DecryptCallback(this); proof_source_->GetTicketCrypter()->Decrypt( in, std::unique_ptr(ticket_decryption_callback_)); + // Decrypt can run the callback synchronously. In that case, the callback // will clear the ticket_decryption_callback_ pointer, and instead of - // returning ssl_ticket_aead_retry, we should continue processing to return - // the decrypted ticket. + // returning ssl_ticket_aead_retry, we should continue processing to + // return the decrypted ticket. // // If the callback is not run synchronously, return ssl_ticket_aead_retry // and when the callback is complete this function will be run again to // return the result. if (ticket_decryption_callback_) { + QUICHE_DCHECK(!ticket_decryption_callback_->IsDone()); set_expected_ssl_error(SSL_ERROR_PENDING_TICKET); if (async_op_timer_.has_value()) { QUIC_CODE_COUNT( @@ -702,10 +822,16 @@ ssl_ticket_aead_result_t TlsServerHandshaker::SessionTicketOpen( } async_op_timer_ = QuicTimeAccumulator(); async_op_timer_->Start(now()); - return ssl_ticket_aead_retry; } } + // If the async ticket decryption is pending, either started by this + // SessionTicketOpen call or one that happened earlier, return + // ssl_ticket_aead_retry. + if (ticket_decryption_callback_ && !ticket_decryption_callback_->IsDone()) { + return ssl_ticket_aead_retry; + } + ssl_ticket_aead_result_t result = FinalizeSessionTicketOpen(out, out_len, max_out_len); @@ -734,7 +860,7 @@ ssl_ticket_aead_result_t TlsServerHandshaker::FinalizeSessionTicketOpen( if (decrypted_session_ticket_.empty()) { QUIC_DLOG(ERROR) << "Session ticket decryption failed; ignoring ticket"; // Ticket decryption failed. Ignore the ticket. - QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_ignored); + QUIC_CODE_COUNT(quic_tls_server_handshaker_tickets_ignored_2); return ssl_ticket_aead_ignore_ticket; } if (max_out_len < decrypted_session_ticket_.size()) { @@ -775,22 +901,29 @@ ssl_select_cert_result_t TlsServerHandshaker::EarlySelectCertCallback( return ssl_select_cert_error; } + { + const uint8_t* unused_extension_bytes; + size_t unused_extension_len; + ticket_received_ = SSL_early_callback_ctx_extension_get( + client_hello, TLSEXT_TYPE_pre_shared_key, &unused_extension_bytes, + &unused_extension_len); + } + // This callback is called very early by Boring SSL, most of the SSL_get_foo // function do not work at this point, but SSL_get_servername does. const char* hostname = SSL_get_servername(ssl(), TLSEXT_NAMETYPE_host_name); if (hostname) { - hostname_ = hostname; crypto_negotiated_params_->sni = - QuicHostnameUtils::NormalizeHostname(hostname_); - if (!ValidateHostname(hostname_)) { + QuicHostnameUtils::NormalizeHostname(hostname); + if (!ValidateHostname(hostname)) { return ssl_select_cert_error; } - if (hostname_ != crypto_negotiated_params_->sni) { + if (hostname != crypto_negotiated_params_->sni) { QUIC_CODE_COUNT(quic_tls_server_hostname_diff); QUIC_LOG_EVERY_N_SEC(WARNING, 300) << "Raw and normalized hostnames differ, but both are valid SNIs. " "raw hostname:" - << hostname_ << ", normalized:" << crypto_negotiated_params_->sni; + << hostname << ", normalized:" << crypto_negotiated_params_->sni; } else { QUIC_CODE_COUNT(quic_tls_server_hostname_same); } @@ -812,15 +945,48 @@ ssl_select_cert_result_t TlsServerHandshaker::EarlySelectCertCallback( return ssl_select_cert_error; } + bssl::UniquePtr ssl_capabilities; + size_t ssl_capabilities_len = 0; + absl::string_view ssl_capabilities_view; + + absl::optional alps; + + if (CryptoUtils::GetSSLCapabilities(ssl(), &ssl_capabilities, + &ssl_capabilities_len)) { + ssl_capabilities_view = + absl::string_view(reinterpret_cast(ssl_capabilities.get()), + ssl_capabilities_len); + } + + // Enable ALPS for the session's ALPN. + SetApplicationSettingsResult alps_result = + SetApplicationSettings(AlpnForVersion(session()->version())); + if (!alps_result.success) { + return ssl_select_cert_error; + } + alps = + alps_result.alps_length > 0 + ? std::string(alps_result.alps_buffer.get(), alps_result.alps_length) + : std::string(); + + if (no_select_cert_if_disconnected_ && + !session()->connection()->connected()) { + QUIC_RELOADABLE_FLAG_COUNT_N(quic_tls_no_select_cert_if_disconnected, 2, 2); + select_cert_status_ = QUIC_FAILURE; + return ssl_select_cert_error; + } + const QuicAsyncStatus status = proof_source_handle_->SelectCertificate( - session()->connection()->self_address(), - session()->connection()->peer_address(), cert_selection_hostname(), + session()->connection()->self_address().Normalized(), + session()->connection()->peer_address().Normalized(), + ssl_capabilities_view, crypto_negotiated_params_->sni, absl::string_view( reinterpret_cast(client_hello->client_hello), client_hello->client_hello_len), - AlpnForVersion(session()->version()), + AlpnForVersion(session()->version()), std::move(alps), set_transport_params_result.quic_transport_params, - set_transport_params_result.early_data_context); + set_transport_params_result.early_data_context, + tls_connection_.ssl_config()); QUICHE_DCHECK_EQ(status, select_cert_status().value()); @@ -842,18 +1008,54 @@ ssl_select_cert_result_t TlsServerHandshaker::EarlySelectCertCallback( } void TlsServerHandshaker::OnSelectCertificateDone( - bool ok, - bool is_sync, - const ProofSource::Chain* chain) { + bool ok, bool is_sync, const ProofSource::Chain* chain, + absl::string_view handshake_hints, absl::string_view ticket_encryption_key, + bool cert_matched_sni, QuicDelayedSSLConfig delayed_ssl_config) { QUIC_DVLOG(1) << "OnSelectCertificateDone. ok:" << ok - << ", is_sync:" << is_sync; + << ", is_sync:" << is_sync + << ", len(handshake_hints):" << handshake_hints.size() + << ", len(ticket_encryption_key):" + << ticket_encryption_key.size(); + absl::optional context_switcher; + if (!is_sync) { + context_switcher.emplace(connection_context()); + } + + QUIC_TRACESTRING(absl::StrCat( + "TLS select certificate done: ok:", ok, + ", len(handshake_hints):", handshake_hints.size(), + ", len(ticket_encryption_key):", ticket_encryption_key.size())); + + ticket_encryption_key_ = std::string(ticket_encryption_key); select_cert_status_ = QUIC_FAILURE; + cert_matched_sni_ = cert_matched_sni; + if (session()->support_client_cert()) { + if (delayed_ssl_config.client_cert_mode.has_value()) { + tls_connection_.SetClientCertMode(*delayed_ssl_config.client_cert_mode); + QUIC_DVLOG(1) << "client_cert_mode after cert selection: " + << client_cert_mode(); + } + } if (ok) { if (chain && !chain->certs.empty()) { tls_connection_.SetCertChain(chain->ToCryptoBuffers().value); + if (!handshake_hints.empty() && + !SSL_set_handshake_hints( + ssl(), reinterpret_cast(handshake_hints.data()), + handshake_hints.size())) { + // If |SSL_set_handshake_hints| fails, the ssl() object will remain + // intact, it is as if we didn't call it. The handshaker will + // continue to compute signature/decrypt ticket as normal. + QUIC_CODE_COUNT(quic_tls_server_set_handshake_hints_failed); + QUIC_DVLOG(1) << "SSL_set_handshake_hints failed"; + } select_cert_status_ = QUIC_SUCCESS; } else { - QUIC_LOG(ERROR) << "No certs provided for host '" << hostname_ << "'"; + QUIC_LOG(ERROR) << "No certs provided for host '" + << crypto_negotiated_params_->sni << "', server_address:" + << session()->connection()->self_address() + << ", client_address:" + << session()->connection()->peer_address(); } } @@ -879,6 +1081,10 @@ void TlsServerHandshaker::OnSelectCertificateDone( } } +bool TlsServerHandshaker::WillNotCallComputeSignature() const { + return SSL_can_release_private_key(ssl()); +} + bool TlsServerHandshaker::ValidateHostname(const std::string& hostname) const { if (!QuicHostnameUtils::IsValidSNI(hostname)) { // TODO(b/151676147): Include this error string in the CONNECTION_CLOSE @@ -928,44 +1134,14 @@ int TlsServerHandshaker::SelectAlpn(const uint8_t** out, alpn_length); } + // TODO(wub): Remove QuicSession::SelectAlpn. QuicSessions should know the + // ALPN on construction. auto selected_alpn = session()->SelectAlpn(alpns); if (selected_alpn == alpns.end()) { QUIC_DLOG(ERROR) << "No known ALPN provided by client"; return SSL_TLSEXT_ERR_NOACK; } - // Enable ALPS for the selected ALPN protocol. - if (GetQuicReloadableFlag(quic_enable_alps_server)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_enable_alps_server); - - const uint8_t* alps_data = nullptr; - size_t alps_length = 0; - std::unique_ptr buffer; - - const std::string& hostname = crypto_negotiated_params_->sni; - std::string accept_ch_value = GetAcceptChValueForOrigin(hostname); - - std::string origin; - if (GetQuicReloadableFlag(quic_alps_include_scheme_in_origin)) { - QUIC_RELOADABLE_FLAG_COUNT(quic_alps_include_scheme_in_origin); - origin = "https://"; - } - origin.append(crypto_negotiated_params_->sni); - - if (!accept_ch_value.empty()) { - AcceptChFrame frame{{{std::move(origin), std::move(accept_ch_value)}}}; - alps_length = HttpEncoder::SerializeAcceptChFrame(frame, &buffer); - alps_data = reinterpret_cast(buffer.get()); - } - - if (SSL_add_application_settings( - ssl(), reinterpret_cast(selected_alpn->data()), - selected_alpn->size(), alps_data, alps_length) != 1) { - QUIC_DLOG(ERROR) << "Failed to enable ALPS"; - return SSL_TLSEXT_ERR_NOACK; - } - } - session()->OnAlpnSelected(*selected_alpn); valid_alpn_received_ = true; *out_len = selected_alpn->size(); @@ -973,4 +1149,39 @@ int TlsServerHandshaker::SelectAlpn(const uint8_t** out, return SSL_TLSEXT_ERR_OK; } +TlsServerHandshaker::SetApplicationSettingsResult +TlsServerHandshaker::SetApplicationSettings(absl::string_view alpn) { + TlsServerHandshaker::SetApplicationSettingsResult result; + const uint8_t* alps_data = nullptr; + + const std::string& hostname = crypto_negotiated_params_->sni; + std::string accept_ch_value = GetAcceptChValueForHostname(hostname); + std::string origin = absl::StrCat("https://", hostname); + uint16_t port = session()->self_address().port(); + if (port != kDefaultPort) { + // This should be rare in production, but useful for test servers. + QUIC_CODE_COUNT(quic_server_alps_non_default_port); + absl::StrAppend(&origin, ":", port); + } + + if (!accept_ch_value.empty()) { + AcceptChFrame frame{{{std::move(origin), std::move(accept_ch_value)}}}; + result.alps_length = + HttpEncoder::SerializeAcceptChFrame(frame, &result.alps_buffer); + alps_data = reinterpret_cast(result.alps_buffer.get()); + } + + if (SSL_add_application_settings( + ssl(), reinterpret_cast(alpn.data()), alpn.size(), + alps_data, result.alps_length) != 1) { + QUIC_DLOG(ERROR) << "Failed to enable ALPS"; + result.success = false; + } else { + result.success = true; + } + return result; +} + +SSL* TlsServerHandshaker::GetSsl() const { return ssl(); } + } // namespace quic diff --git a/gquiche/quic/core/tls_server_handshaker.h b/gquiche/quic/core/tls_server_handshaker.h index be2e3efe..99d1ec4f 100644 --- a/gquiche/quic/core/tls_server_handshaker.h +++ b/gquiche/quic/core/tls_server_handshaker.h @@ -19,6 +19,8 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/tls_handshaker.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/quic/platform/api/quic_flag_utils.h" +#include "gquiche/quic/platform/api/quic_flags.h" namespace quic { @@ -56,11 +58,16 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker void OnConnectionClosed(QuicErrorCode error, ConnectionCloseSource source) override; void OnHandshakeDoneReceived() override; - std::string GetAddressToken() const override; + std::string GetAddressToken( + const CachedNetworkParameters* cached_network_params) const override; bool ValidateAddressToken(absl::string_view token) const override; void OnNewTokenReceived(absl::string_view token) override; bool ShouldSendExpectCTHeader() const override; + bool DidCertMatchSni() const override; const ProofSource::Details* ProofSourceDetails() const override; + bool ExportKeyingMaterial(absl::string_view label, absl::string_view context, + size_t result_len, std::string* result) override; + SSL* GetSsl() const override; // From QuicCryptoServerStreamBase and TlsHandshaker ssl_early_data_reason_t EarlyDataReason() const override; @@ -81,12 +88,20 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker const SSL_CIPHER* cipher, const std::vector& write_secret) override; - // Called with normalized SNI hostname as |origin|. Return value will be sent - // in an ACCEPT_CH frame in the TLS ALPS extension, unless empty. - virtual std::string GetAcceptChValueForOrigin( - const std::string& origin) const; + // Called with normalized SNI hostname as |hostname|. Return value will be + // sent in an ACCEPT_CH frame in the TLS ALPS extension, unless empty. + virtual std::string GetAcceptChValueForHostname( + const std::string& hostname) const; + + // Get the ClientCertMode that is currently in effect on this handshaker. + ClientCertMode client_cert_mode() const { + return tls_connection_.ssl_config().client_cert_mode; + } protected: + // Override for tracing. + void InfoCallback(int type, int value) override; + // Creates a proof source handle for selecting cert and computing signature. virtual std::unique_ptr MaybeCreateProofSourceHandle(); @@ -96,14 +111,6 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker virtual bool ValidateHostname(const std::string& hostname) const; - // The hostname to be used to select certificates and compute signatures. - // The function should only be called after a successful ValidateHostname(). - const std::string& cert_selection_hostname() const { - return use_normalized_sni_for_cert_selection_ - ? crypto_negotiated_params_->sni - : hostname_; - } - const TlsConnection* tls_connection() const override { return &tls_connection_; } @@ -172,9 +179,11 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker bool HasValidSignature(size_t max_signature_size) const; // ProofSourceHandleCallback implementation: - void OnSelectCertificateDone(bool ok, - bool is_sync, - const ProofSource::Chain* chain) override; + void OnSelectCertificateDone( + bool ok, bool is_sync, const ProofSource::Chain* chain, + absl::string_view handshake_hints, + absl::string_view ticket_encryption_key, bool cert_matched_sni, + QuicDelayedSSLConfig delayed_ssl_config) override; void OnComputeSignatureDone( bool ok, @@ -182,6 +191,14 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker std::string signature, std::unique_ptr details) override; + void set_encryption_established(bool encryption_established) { + encryption_established_ = encryption_established; + } + + bool WillNotCallComputeSignature() const override; + + void SetIgnoreTicketOpen(bool value) { ignore_ticket_open_ = value; } + private: class QUIC_EXPORT_PRIVATE DecryptCallback : public ProofSource::DecryptCallback { @@ -192,6 +209,11 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker // If called, Cancel causes the pending callback to be a no-op. void Cancel(); + // Return true if either + // - Cancel() has been called. + // - Run() has been called, or is in the middle of it. + bool IsDone() const { return handshaker_ == nullptr; } + private: TlsServerHandshaker* handshaker_; }; @@ -206,20 +228,22 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker ~DefaultProofSourceHandle() override; - // Cancel the pending signature operation, if any. - void CancelPendingOperation() override; + // Close the handle. Cancel the pending signature operation, if any. + void CloseHandle() override; // Delegates to proof_source_->GetCertChain. // Returns QUIC_SUCCESS or QUIC_FAILURE. Never returns QUIC_PENDING. QuicAsyncStatus SelectCertificate( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, + absl::string_view ssl_capabilities, const std::string& hostname, absl::string_view client_hello, const std::string& alpn, + absl::optional alps, const std::vector& quic_transport_params, - const absl::optional>& early_data_context) - override; + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) override; // Delegates to proof_source_->ComputeTlsSignature. // Returns QUIC_SUCCESS, QUIC_FAILURE or QUIC_PENDING. @@ -247,9 +271,13 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker // Operation has been canceled, or Run has been called. return; } - handle_->signature_callback_ = nullptr; - if (handle_->handshaker_ != nullptr) { - handle_->handshaker_->OnComputeSignatureDone( + + DefaultProofSourceHandle* handle = handle_; + handle_ = nullptr; + + handle->signature_callback_ = nullptr; + if (handle->handshaker_ != nullptr) { + handle->handshaker_->OnComputeSignatureDone( ok, is_sync_, std::move(signature), std::move(details)); } } @@ -284,11 +312,22 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker bool ProcessTransportParameters(const SSL_CLIENT_HELLO* client_hello, std::string* error_details); + struct QUIC_NO_EXPORT SetApplicationSettingsResult { + bool success = false; + std::unique_ptr alps_buffer; + size_t alps_length = 0; + }; + SetApplicationSettingsResult SetApplicationSettings(absl::string_view alpn); + QuicConnectionStats& connection_stats() { return session()->connection()->mutable_stats(); } QuicTime now() const { return session()->GetClock()->Now(); } + QuicConnectionContext* connection_context() { + return session()->connection()->context(); + } + std::unique_ptr proof_source_handle_; ProofSource* proof_source_; @@ -306,10 +345,12 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker // indicates that the client attempted a resumption. bool ticket_received_ = false; + // Force SessionTicketOpen to return ssl_ticket_aead_ignore_ticket if called. + bool ignore_ticket_open_ = false; + // nullopt means select cert hasn't started. absl::optional select_cert_status_; - std::string hostname_; std::string cert_verify_sig_; std::unique_ptr proof_source_details_; @@ -321,15 +362,23 @@ class QUIC_EXPORT_PRIVATE TlsServerHandshaker // Pre-shared key used during the handshake. std::string pre_shared_key_; + // (optional) Key to use for encrypting TLS resumption tickets. + std::string ticket_encryption_key_; + HandshakeState state_ = HANDSHAKE_START; bool encryption_established_ = false; bool valid_alpn_received_ = false; QuicReferenceCountedPointer crypto_negotiated_params_; TlsServerConnection tls_connection_; - const bool use_normalized_sni_for_cert_selection_ = - GetQuicReloadableFlag(quic_tls_use_normalized_sni_for_cert_selectioon); const QuicCryptoServerConfig* crypto_config_; // Unowned. + // The last received CachedNetworkParameters from a validated address token. + mutable std::unique_ptr + last_received_cached_network_params_; + + bool cert_matched_sni_ = false; + const bool no_select_cert_if_disconnected_ = + GetQuicReloadableFlag(quic_tls_no_select_cert_if_disconnected); }; } // namespace quic diff --git a/gquiche/quic/core/tls_server_handshaker_test.cc b/gquiche/quic/core/tls_server_handshaker_test.cc index 2b1383c6..1a0ff0ef 100644 --- a/gquiche/quic/core/tls_server_handshaker_test.cc +++ b/gquiche/quic/core/tls_server_handshaker_test.cc @@ -2,20 +2,23 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +#include "gquiche/quic/core/tls_server_handshaker.h" + #include #include #include #include "absl/base/macros.h" #include "absl/strings/string_view.h" +#include "gquiche/quic/core/crypto/client_proof_source.h" #include "gquiche/quic/core/crypto/proof_source.h" #include "gquiche/quic/core/crypto/quic_random.h" #include "gquiche/quic/core/quic_crypto_client_stream.h" #include "gquiche/quic/core/quic_session.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/core/tls_client_handshaker.h" -#include "gquiche/quic/core/tls_server_handshaker.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_test.h" @@ -23,8 +26,10 @@ #include "gquiche/quic/test_tools/failing_proof_source.h" #include "gquiche/quic/test_tools/fake_proof_source.h" #include "gquiche/quic/test_tools/fake_proof_source_handle.h" +#include "gquiche/quic/test_tools/quic_config_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" #include "gquiche/quic/test_tools/simple_session_cache.h" +#include "gquiche/quic/test_tools/test_certificates.h" #include "gquiche/quic/test_tools/test_ticket_crypter.h" namespace quic { @@ -44,6 +49,21 @@ namespace { const char kServerHostname[] = "test.example.com"; const uint16_t kServerPort = 443; +QuicReferenceCountedPointer TestClientCertChain() { + return QuicReferenceCountedPointer( + new ClientProofSource::Chain({std::string(kTestCertificate)})); +} + +CertificatePrivateKey TestClientCertPrivateKey() { + CBS private_key_cbs; + CBS_init(&private_key_cbs, + reinterpret_cast(kTestCertificatePrivateKey.data()), + kTestCertificatePrivateKey.size()); + + return CertificatePrivateKey( + bssl::UniquePtr(EVP_parse_private_key(&private_key_cbs))); +} + struct TestParams { ParsedQuicVersion version; bool disable_resumption; @@ -76,6 +96,10 @@ class TestTlsServerHandshaker : public TlsServerHandshaker { ON_CALL(*this, MaybeCreateProofSourceHandle()) .WillByDefault(testing::Invoke( this, &TestTlsServerHandshaker::RealMaybeCreateProofSourceHandle)); + + ON_CALL(*this, OverrideQuicConfigDefaults(_)) + .WillByDefault(testing::Invoke( + this, &TestTlsServerHandshaker::RealOverrideQuicConfigDefaults)); } MOCK_METHOD(std::unique_ptr, @@ -83,15 +107,20 @@ class TestTlsServerHandshaker : public TlsServerHandshaker { (), (override)); + MOCK_METHOD(void, OverrideQuicConfigDefaults, (QuicConfig * config), + (override)); + void SetupProofSourceHandle( FakeProofSourceHandle::Action select_cert_action, - FakeProofSourceHandle::Action compute_signature_action) { + FakeProofSourceHandle::Action compute_signature_action, + QuicDelayedSSLConfig dealyed_ssl_config = QuicDelayedSSLConfig()) { EXPECT_CALL(*this, MaybeCreateProofSourceHandle()) - .WillOnce(testing::Invoke( - [this, select_cert_action, compute_signature_action]() { + .WillOnce( + testing::Invoke([this, select_cert_action, compute_signature_action, + dealyed_ssl_config]() { auto handle = std::make_unique( proof_source_, this, select_cert_action, - compute_signature_action); + compute_signature_action, dealyed_ssl_config); fake_proof_source_handle_ = handle.get(); return handle; })); @@ -101,16 +130,34 @@ class TestTlsServerHandshaker : public TlsServerHandshaker { return fake_proof_source_handle_; } + bool received_client_cert() const { return received_client_cert_; } + + using TlsServerHandshaker::AdvanceHandshake; using TlsServerHandshaker::expected_ssl_error; + protected: + QuicAsyncStatus VerifyCertChain( + const std::vector& certs, std::string* error_details, + std::unique_ptr* details, uint8_t* out_alert, + std::unique_ptr callback) override { + received_client_cert_ = true; + return TlsServerHandshaker::VerifyCertChain(certs, error_details, details, + out_alert, std::move(callback)); + } + private: std::unique_ptr RealMaybeCreateProofSourceHandle() { return TlsServerHandshaker::MaybeCreateProofSourceHandle(); } + void RealOverrideQuicConfigDefaults(QuicConfig* config) { + return TlsServerHandshaker::OverrideQuicConfigDefaults(config); + } + // Owned by TlsServerHandshaker. FakeProofSourceHandle* fake_proof_source_handle_ = nullptr; ProofSource* proof_source_ = nullptr; + bool received_client_cert_ = false; }; class TlsServerHandshakerTestSession : public TestQuicSpdyServerSession { @@ -184,6 +231,7 @@ class TlsServerHandshakerTest : public QuicTestWithParam { new TlsServerHandshakerTestSession( server_connection_, DefaultQuicConfig(), supported_versions_, server_crypto_config_.get(), &server_compressed_certs_cache_); + server_session->set_client_cert_mode(initial_client_cert_mode_); server_session->Initialize(); // We advance the clock initially because the default time is zero and the @@ -368,6 +416,7 @@ class TlsServerHandshakerTest : public QuicTestWithParam { std::unique_ptr server_crypto_config_; QuicCompressedCertsCache server_compressed_certs_cache_; QuicServerId server_id_; + ClientCertMode initial_client_cert_mode_ = ClientCertMode::kNone; // Client state. PacketSavingConnection* client_connection_; @@ -557,23 +606,30 @@ TEST_P(TlsServerHandshakerTest, HostnameForCertSelectionAndComputeSignature) { EXPECT_EQ(server_stream()->crypto_negotiated_params().sni, "test.example.com"); - if (GetQuicReloadableFlag(quic_tls_use_normalized_sni_for_cert_selectioon)) { - EXPECT_EQ(last_select_cert_args().hostname, "test.example.com"); - EXPECT_EQ(last_compute_signature_args().hostname, "test.example.com"); - } else { - EXPECT_EQ(last_select_cert_args().hostname, "tEsT.EXAMPLE.CoM"); - EXPECT_EQ(last_compute_signature_args().hostname, "tEsT.EXAMPLE.CoM"); - } + EXPECT_EQ(last_select_cert_args().hostname, "test.example.com"); + EXPECT_EQ(last_compute_signature_args().hostname, "test.example.com"); +} + +TEST_P(TlsServerHandshakerTest, SSLConfigForCertSelection) { + InitializeServerWithFakeProofSourceHandle(); + + // Disable early data. + server_session_->set_early_data_enabled(false); + + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + EXPECT_FALSE(last_select_cert_args().ssl_config.early_data_enabled); } TEST_P(TlsServerHandshakerTest, ConnectionClosedOnTlsError) { - if (GetQuicReloadableFlag(quic_send_tls_crypto_error_code)) { - EXPECT_CALL(*server_connection_, - CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _)); - } else { - EXPECT_CALL(*server_connection_, - CloseConnection(QUIC_HANDSHAKE_FAILED, _, _)); - } + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, _, _, _)); // Send a zero-length ClientHello from client to server. char bogus_handshake_message[] = { @@ -602,23 +658,14 @@ TEST_P(TlsServerHandshakerTest, ClientSendingBadALPN) { const std::string kTestBadClientAlpn = "bad-client-alpn"; EXPECT_CALL(*client_session_, GetAlpnsToOffer()) .WillOnce(Return(std::vector({kTestBadClientAlpn}))); - if (GetQuicReloadableFlag(quic_send_tls_crypto_error_code)) { - EXPECT_CALL( - *server_connection_, - CloseConnection( - QUIC_HANDSHAKE_FAILED, - static_cast(CRYPTO_ERROR_FIRST + 120), - "TLS handshake failure (ENCRYPTION_INITIAL) 120: " - "no application protocol", - _)); - } else { - EXPECT_CALL( - *server_connection_, - CloseConnection(QUIC_HANDSHAKE_FAILED, - "TLS handshake failure (ENCRYPTION_INITIAL) 120: " - "no application protocol", - _)); - } + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_HANDSHAKE_FAILED, + static_cast( + CRYPTO_ERROR_FIRST + 120), + "TLS handshake failure (ENCRYPTION_INITIAL) 120: " + "no application protocol", + _)); AdvanceHandshakeWithFakeClient(); @@ -710,6 +757,41 @@ TEST_P(TlsServerHandshakerTest, ResumptionWithAsyncDecryptCallback) { EXPECT_TRUE(server_stream()->ResumptionAttempted()); } +TEST_P(TlsServerHandshakerTest, AdvanceHandshakeDuringAsyncDecryptCallback) { + if (GetParam().disable_resumption) { + return; + } + + // Do the first handshake + InitializeFakeClient(); + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + ticket_crypter_->SetRunCallbacksAsync(true); + // Now do another handshake + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + InitializeFakeClient(); + + AdvanceHandshakeWithFakeClient(); + + // Ensure an async DecryptCallback is now pending. + ASSERT_EQ(ticket_crypter_->NumPendingCallbacks(), 1u); + + { + QuicConnection::ScopedPacketFlusher flusher(server_connection_); + server_handshaker_->AdvanceHandshake(); + } + + // This will delete |server_handshaker_|. + server_session_ = nullptr; + + ticket_crypter_->RunPendingCallback(0); // Should not crash. +} + TEST_P(TlsServerHandshakerTest, ResumptionWithFailingDecryptCallback) { if (GetParam().disable_resumption) { return; @@ -818,6 +900,179 @@ TEST_P(TlsServerHandshakerTest, ZeroRttRejectOnApplicationStateChange) { EXPECT_FALSE(server_stream()->IsZeroRtt()); } +TEST_P(TlsServerHandshakerTest, RequestClientCert) { + auto client_proof_source = std::make_unique(); + ASSERT_TRUE(client_proof_source->AddCertAndKey({"*"}, TestClientCertChain(), + TestClientCertPrivateKey())); + client_crypto_config_->set_proof_source(std::move(client_proof_source)); + InitializeFakeClient(); + + initial_client_cert_mode_ = ClientCertMode::kRequest; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + if (GetQuicRestartFlag(quic_tls_server_support_client_cert)) { + EXPECT_TRUE(server_handshaker_->received_client_cert()); + } else { + EXPECT_FALSE(server_handshaker_->received_client_cert()); + } +} + +TEST_P(TlsServerHandshakerTest, RequestClientCertByDelayedSslConfig) { + auto client_proof_source = std::make_unique(); + ASSERT_TRUE(client_proof_source->AddCertAndKey({"*"}, TestClientCertChain(), + TestClientCertPrivateKey())); + client_crypto_config_->set_proof_source(std::move(client_proof_source)); + InitializeFakeClient(); + + QuicDelayedSSLConfig delayed_ssl_config; + delayed_ssl_config.client_cert_mode = ClientCertMode::kRequest; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + delayed_ssl_config); + + AdvanceHandshakeWithFakeClient(); + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + if (GetQuicRestartFlag(quic_tls_server_support_client_cert)) { + EXPECT_TRUE(server_handshaker_->received_client_cert()); + } else { + EXPECT_FALSE(server_handshaker_->received_client_cert()); + } +} + +TEST_P(TlsServerHandshakerTest, RequestClientCert_NoCert) { + initial_client_cert_mode_ = ClientCertMode::kRequest; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + EXPECT_FALSE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, RequestAndRequireClientCert) { + auto client_proof_source = std::make_unique(); + ASSERT_TRUE(client_proof_source->AddCertAndKey({"*"}, TestClientCertChain(), + TestClientCertPrivateKey())); + client_crypto_config_->set_proof_source(std::move(client_proof_source)); + InitializeFakeClient(); + + initial_client_cert_mode_ = ClientCertMode::kRequire; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + + if (GetQuicRestartFlag(quic_tls_server_support_client_cert)) { + EXPECT_TRUE(server_handshaker_->received_client_cert()); + } else { + EXPECT_FALSE(server_handshaker_->received_client_cert()); + } +} + +TEST_P(TlsServerHandshakerTest, RequestAndRequireClientCertByDelayedSslConfig) { + auto client_proof_source = std::make_unique(); + ASSERT_TRUE(client_proof_source->AddCertAndKey({"*"}, TestClientCertChain(), + TestClientCertPrivateKey())); + client_crypto_config_->set_proof_source(std::move(client_proof_source)); + InitializeFakeClient(); + + QuicDelayedSSLConfig delayed_ssl_config; + delayed_ssl_config.client_cert_mode = ClientCertMode::kRequire; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_ASYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + delayed_ssl_config); + + AdvanceHandshakeWithFakeClient(); + ASSERT_TRUE( + server_handshaker_->fake_proof_source_handle()->HasPendingOperation()); + server_handshaker_->fake_proof_source_handle()->CompletePendingOperation(); + + CompleteCryptoHandshake(); + ExpectHandshakeSuccessful(); + if (GetQuicRestartFlag(quic_tls_server_support_client_cert)) { + EXPECT_TRUE(server_handshaker_->received_client_cert()); + } else { + EXPECT_FALSE(server_handshaker_->received_client_cert()); + } +} + +TEST_P(TlsServerHandshakerTest, RequestAndRequireClientCert_NoCert) { + initial_client_cert_mode_ = ClientCertMode::kRequire; + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action::DELEGATE_SYNC, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + DELEGATE_SYNC); + + if (GetQuicRestartFlag(quic_tls_server_support_client_cert)) { + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_TLS_CERTIFICATE_REQUIRED, _, _, _)); + } + AdvanceHandshakeWithFakeClient(); + AdvanceHandshakeWithFakeClient(); + EXPECT_FALSE(server_handshaker_->received_client_cert()); +} + +TEST_P(TlsServerHandshakerTest, CloseConnectionBeforeSelectCert) { + InitializeServerWithFakeProofSourceHandle(); + server_handshaker_->SetupProofSourceHandle( + /*select_cert_action=*/FakeProofSourceHandle::Action:: + FAIL_SYNC_DO_NOT_CHECK_CLOSED, + /*compute_signature_action=*/FakeProofSourceHandle::Action:: + FAIL_SYNC_DO_NOT_CHECK_CLOSED); + + EXPECT_CALL(*server_handshaker_, OverrideQuicConfigDefaults(_)) + .WillOnce(testing::Invoke([](QuicConfig* config) { + QuicConfigPeer::SetReceivedMaxUnidirectionalStreams(config, + /*max_streams=*/0); + })); + + EXPECT_CALL(*server_connection_, + CloseConnection(QUIC_ZERO_RTT_RESUMPTION_LIMIT_REDUCED, _, _)) + .WillOnce(testing::Invoke( + [this](QuicErrorCode error, const std::string& details, + ConnectionCloseBehavior connection_close_behavior) { + server_connection_->ReallyCloseConnection( + error, details, connection_close_behavior); + ASSERT_FALSE(server_connection_->connected()); + })); + + AdvanceHandshakeWithFakeClient(); + if (!GetQuicReloadableFlag(quic_tls_no_select_cert_if_disconnected)) { + // SelectCertificate is called when flag is false. + EXPECT_FALSE(server_handshaker_->fake_proof_source_handle() + ->all_select_cert_args() + .empty()); + return; + } + + EXPECT_TRUE(server_handshaker_->fake_proof_source_handle() + ->all_select_cert_args() + .empty()); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/core/uber_received_packet_manager.cc b/gquiche/quic/core/uber_received_packet_manager.cc index e00ca688..27494511 100644 --- a/gquiche/quic/core/uber_received_packet_manager.cc +++ b/gquiche/quic/core/uber_received_packet_manager.cc @@ -120,14 +120,11 @@ void UberReceivedPacketManager::EnableMultiplePacketNumberSpacesSupport( } // In IETF QUIC, the peer is expected to acknowledge packets in Initial and // Handshake packets with minimal delay. - if (!GetQuicReloadableFlag(quic_delay_initial_ack) || - perspective == Perspective::IS_CLIENT) { + if (perspective == Perspective::IS_CLIENT) { // Delay the first server ACK, because server ACKs are padded to // full size and count towards the amplification limit. received_packet_managers_[INITIAL_DATA].set_local_max_ack_delay( kAlarmGranularity); - } else { - QUIC_RELOADABLE_FLAG_COUNT(quic_delay_initial_ack); } received_packet_managers_[HANDSHAKE_DATA].set_local_max_ack_delay( kAlarmGranularity); diff --git a/gquiche/quic/core/uber_received_packet_manager_test.cc b/gquiche/quic/core/uber_received_packet_manager_test.cc index c9b91e8f..386cd12a 100644 --- a/gquiche/quic/core/uber_received_packet_manager_test.cc +++ b/gquiche/quic/core/uber_received_packet_manager_test.cc @@ -368,7 +368,6 @@ TEST_F(UberReceivedPacketManagerTest, EXPECT_FALSE(HasPendingAck()); QuicConfig config; QuicTagVector connection_options; - connection_options.push_back(kACKD); // No limit on the number of packets received before sending an ack. connection_options.push_back(kAKDU); config.SetConnectionOptionsToSend(connection_options); @@ -478,17 +477,10 @@ TEST_F(UberReceivedPacketManagerTest, AckSendingDifferentPacketNumberSpaces) { MaybeUpdateAckTimeout(kInstigateAck, ENCRYPTION_INITIAL, 3); EXPECT_TRUE(HasPendingAck()); // Delayed ack is scheduled. - if (GetQuicReloadableFlag(quic_delay_initial_ack)) { - CheckAckTimeout(clock_.ApproximateNow() + - QuicTime::Delta::FromMilliseconds(25)); - // Send delayed handshake data ACK. - clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(25)); - } else { - CheckAckTimeout(clock_.ApproximateNow() + - QuicTime::Delta::FromMilliseconds(1)); - // Send delayed handshake data ACK. - clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(1)); - } + CheckAckTimeout(clock_.ApproximateNow() + + QuicTime::Delta::FromMilliseconds(25)); + // Send delayed handshake data ACK. + clock_.AdvanceTime(QuicTime::Delta::FromMilliseconds(25)); CheckAckTimeout(clock_.ApproximateNow()); EXPECT_FALSE(HasPendingAck()); diff --git a/gquiche/quic/core/web_transport_interface.h b/gquiche/quic/core/web_transport_interface.h index b1375737..bcdd60cf 100644 --- a/gquiche/quic/core/web_transport_interface.h +++ b/gquiche/quic/core/web_transport_interface.h @@ -16,6 +16,7 @@ #include "gquiche/quic/core/quic_datagram_queue.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_export.h" +#include "gquiche/spdy/core/spdy_header_block.h" namespace quic { @@ -28,6 +29,14 @@ class QUIC_EXPORT_PRIVATE WebTransportStreamVisitor { virtual void OnCanRead() = 0; // Called whenever the stream is not write-blocked and can accept new data. virtual void OnCanWrite() = 0; + + // Called when RESET_STREAM is received for the stream. + virtual void OnResetStreamReceived(WebTransportStreamError error) = 0; + // Called when STOP_SENDING is received for the stream. + virtual void OnStopSendingReceived(WebTransportStreamError error) = 0; + // Called when the write side of the stream is closed and all of the data sent + // has been acknowledged ("Data Recvd" state of RFC 9000). + virtual void OnWriteSideInDataRecvdState() = 0; }; // A stream (either bidirectional or unidirectional) that is contained within a @@ -66,9 +75,9 @@ class QUIC_EXPORT_PRIVATE WebTransportStream { virtual QuicStreamId GetStreamId() const = 0; // Resets the stream with the specified error code. - // TODO(b/184048994): change the error code type based on IETF consensus. - virtual void ResetWithUserCode(QuicRstStreamErrorCode error) = 0; + virtual void ResetWithUserCode(WebTransportStreamError error) = 0; virtual void ResetDueToInternalError() = 0; + virtual void SendStopSending(WebTransportStreamError error) = 0; // Called when the owning object has been garbage-collected. virtual void MaybeResetDueToStreamObjectGone() = 0; @@ -84,7 +93,11 @@ class QUIC_EXPORT_PRIVATE WebTransportVisitor { // Notifies the visitor when the session is ready to exchange application // data. - virtual void OnSessionReady() = 0; + virtual void OnSessionReady(const spdy::SpdyHeaderBlock& headers) = 0; + + // Notifies the visitor when the session has been closed. + virtual void OnSessionClosed(WebTransportSessionError error_code, + const std::string& error_message) = 0; // Notifies the visitor when a new stream has been received. The stream in // question can be retrieved using AcceptIncomingBidirectionalStream() or @@ -105,6 +118,11 @@ class QUIC_EXPORT_PRIVATE WebTransportSession { public: virtual ~WebTransportSession() {} + // Closes the WebTransport session in question with the specified |error_code| + // and |error_message|. + virtual void CloseSession(WebTransportSessionError error_code, + absl::string_view error_message) = 0; + // Return the earliest incoming stream that has been received by the session // but has not been accepted. Returns nullptr if there are no incoming // streams. @@ -120,6 +138,9 @@ class QUIC_EXPORT_PRIVATE WebTransportSession { virtual WebTransportStream* OpenOutgoingUnidirectionalStream() = 0; virtual MessageStatus SendOrQueueDatagram(QuicMemSlice datagram) = 0; + // Returns a conservative estimate of the largest datagram size that the + // session would be able to send. + virtual QuicByteCount GetMaxDatagramSize() const = 0; // Sets the largest duration that a datagram can spend in the queue before // being silently dropped. virtual void SetDatagramMaxTimeInQueue(QuicTime::Delta max_time_in_queue) = 0; diff --git a/gquiche/quic/masque/masque_client_bin.cc b/gquiche/quic/masque/masque_client_bin.cc index b7607e29..2cb9ff16 100644 --- a/gquiche/quic/masque/masque_client_bin.cc +++ b/gquiche/quic/masque/masque_client_bin.cc @@ -8,9 +8,11 @@ // e.g.: masque_client $PROXY_HOST:$PROXY_PORT $URL1 $URL2 #include +#include #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "url/third_party/mozilla/url_parse.h" #include "gquiche/quic/core/quic_server_id.h" #include "gquiche/quic/masque/masque_client_tools.h" #include "gquiche/quic/masque/masque_encapsulated_epoll_client.h" @@ -21,8 +23,6 @@ #include "gquiche/quic/platform/api/quic_socket_address.h" #include "gquiche/quic/platform/api/quic_system_event_loop.h" #include "gquiche/quic/tools/fake_proof_verifier.h" -#include "gquiche/quic/tools/quic_url.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" DEFINE_QUIC_COMMAND_LINE_FLAG(bool, disable_certificate_verification, @@ -43,34 +43,44 @@ int RunMasqueClient(int argc, char* argv[]) { QuicSystemEventLoop event_loop("masque_client"); const char* usage = "Usage: masque_client [options] "; - // The first non-flag argument is the MASQUE server. All subsequent ones are - // interpreted as URLs to fetch via the MASQUE server. + // The first non-flag argument is the URI template of the MASQUE server. + // All subsequent ones are interpreted as URLs to fetch via the MASQUE server. + // Note that the URI template expansion currently only supports string + // replacement of {target_host} and {target_port}, not + // {?target_host,target_port}. std::vector urls = QuicParseCommandLineFlags(usage, argc, argv); if (urls.empty()) { QuicPrintCommandLineFlagHelp(usage); return 1; } - SetQuicReloadableFlag(quic_h3_datagram, true); - const bool disable_certificate_verification = GetQuicFlag(FLAGS_disable_certificate_verification); QuicEpollServer epoll_server; - QuicUrl masque_url(urls[0], "https"); - if (masque_url.host().empty()) { - masque_url = QuicUrl(absl::StrCat("https://", urls[0]), "https"); + std::string uri_template = urls[0]; + if (!absl::StrContains(uri_template, '/')) { + // Allow passing in authority instead of URI template. + uri_template = + absl::StrCat("https://", uri_template, "/{target_host}/{target_port}/"); } - if (masque_url.host().empty()) { - std::cerr << "Failed to parse MASQUE server address \"" << urls[0] << "\"" + url::Parsed parsed_uri_template; + url::ParseStandardURL(uri_template.c_str(), uri_template.length(), + &parsed_uri_template); + if (!parsed_uri_template.scheme.is_nonempty() || + !parsed_uri_template.host.is_nonempty() || + !parsed_uri_template.path.is_nonempty()) { + std::cerr << "Failed to parse MASQUE URI template \"" << urls[0] << "\"" << std::endl; return 1; } + std::string host = uri_template.substr(parsed_uri_template.host.begin, + parsed_uri_template.host.len); std::unique_ptr proof_verifier; if (disable_certificate_verification) { proof_verifier = std::make_unique(); } else { - proof_verifier = CreateDefaultProofVerifier(masque_url.host()); + proof_verifier = CreateDefaultProofVerifier(host); } MasqueMode masque_mode = MasqueMode::kOpen; std::string mode_string = GetQuicFlag(FLAGS_masque_mode); @@ -81,8 +91,7 @@ int RunMasqueClient(int argc, char* argv[]) { return 1; } std::unique_ptr masque_client = MasqueEpollClient::Create( - masque_url.host(), masque_url.port(), masque_mode, &epoll_server, - std::move(proof_verifier)); + uri_template, masque_mode, &epoll_server, std::move(proof_verifier)); if (masque_client == nullptr) { return 1; } diff --git a/gquiche/quic/masque/masque_client_session.cc b/gquiche/quic/masque/masque_client_session.cc index 668326de..5cfd3a94 100644 --- a/gquiche/quic/masque/masque_client_session.cc +++ b/gquiche/quic/masque/masque_client_session.cc @@ -3,30 +3,33 @@ // found in the LICENSE file. #include "gquiche/quic/masque/masque_client_session.h" + +#include + #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "url/url_canon.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/core/quic_data_reader.h" #include "gquiche/quic/core/quic_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/quic/platform/api/quic_socket_address.h" +#include "gquiche/quic/tools/quic_url.h" +#include "gquiche/common/platform/api/quiche_url_utils.h" namespace quic { MasqueClientSession::MasqueClientSession( - MasqueMode masque_mode, - const QuicConfig& config, - const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - const QuicServerId& server_id, + MasqueMode masque_mode, const std::string& uri_template, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, - QuicClientPushPromiseIndex* push_promise_index, - Owner* owner) - : QuicSpdyClientSession(config, - supported_versions, - connection, - server_id, - crypto_config, - push_promise_index), + QuicClientPushPromiseIndex* push_promise_index, Owner* owner) + : QuicSpdyClientSession(config, supported_versions, connection, server_id, + crypto_config, push_promise_index), masque_mode_(masque_mode), + uri_template_(uri_template), owner_(owner), compression_engine_(this) {} @@ -86,6 +89,51 @@ MasqueClientSession::GetOrCreateConnectUdpClientState( } } // No CONNECT-UDP request found, create a new one. + + url::Parsed parsed_uri_template; + url::ParseStandardURL(uri_template_.c_str(), uri_template_.length(), + &parsed_uri_template); + if (!parsed_uri_template.path.is_nonempty()) { + QUIC_BUG(bad URI template path) + << "Cannot parse path from URI template " << uri_template_; + return nullptr; + } + std::string path = uri_template_.substr(parsed_uri_template.path.begin, + parsed_uri_template.path.len); + if (parsed_uri_template.query.is_valid()) { + absl::StrAppend(&path, "?", + uri_template_.substr(parsed_uri_template.query.begin, + parsed_uri_template.query.len)); + } + absl::flat_hash_map parameters; + parameters["target_host"] = target_server_address.host().ToString(); + parameters["target_port"] = absl::StrCat(target_server_address.port()); + std::string expanded_path; + absl::flat_hash_set vars_found; + bool expanded = + quiche::ExpandURITemplate(path, parameters, &expanded_path, &vars_found); + if (!expanded || vars_found.find("target_host") == vars_found.end() || + vars_found.find("target_port") == vars_found.end()) { + QUIC_DLOG(ERROR) << "Failed to expand URI template \"" << uri_template_ + << "\" for " << target_server_address; + return nullptr; + } + + url::Component expanded_path_component(0, expanded_path.length()); + url::RawCanonOutput<1024> canonicalized_path_output; + url::Component canonicalized_path_component; + bool canonicalized = url::CanonicalizePath( + expanded_path.c_str(), expanded_path_component, + &canonicalized_path_output, &canonicalized_path_component); + if (!canonicalized || !canonicalized_path_component.is_nonempty()) { + QUIC_DLOG(ERROR) << "Failed to canonicalize URI template \"" + << uri_template_ << "\" for " << target_server_address; + return nullptr; + } + std::string canonicalized_path( + canonicalized_path_output.data() + canonicalized_path_component.begin, + canonicalized_path_component.len); + QuicSpdyClientStream* stream = CreateOutgoingBidirectionalStream(); if (stream == nullptr) { // Stream flow control limits prevented us from opening a new stream. @@ -93,19 +141,26 @@ MasqueClientSession::GetOrCreateConnectUdpClientState( return nullptr; } - QuicDatagramFlowId flow_id = GetNextDatagramFlowId(); + QuicUrl url(uri_template_); + std::string scheme = url.scheme(); + std::string authority = url.HostPort(); QUIC_DLOG(INFO) << "Sending CONNECT-UDP request for " << target_server_address - << " using flow ID " << flow_id << " on stream " - << stream->id(); + << " on stream " << stream->id() << " scheme=\"" << scheme + << "\" authority=\"" << authority << "\" path=\"" + << canonicalized_path << "\""; // Send the request. spdy::Http2HeaderBlock headers; - headers[":method"] = "CONNECT-UDP"; - headers[":scheme"] = "masque"; - headers[":path"] = "/"; - headers[":authority"] = target_server_address.ToString(); - SpdyUtils::AddDatagramFlowIdHeader(&headers, flow_id); + headers[":method"] = "CONNECT"; + headers[":protocol"] = "connect-udp"; + headers[":scheme"] = scheme; + headers[":authority"] = authority; + headers[":path"] = canonicalized_path; + headers["connect-udp-version"] = "6"; + if (http_datagram_support() == HttpDatagramSupport::kDraft00) { + SpdyUtils::AddDatagramFlowIdHeader(&headers, stream->id()); + } size_t bytes_sent = stream->SendRequest(std::move(headers), /*body=*/"", /*fin=*/false); if (bytes_sent == 0) { @@ -113,16 +168,16 @@ MasqueClientSession::GetOrCreateConnectUdpClientState( return nullptr; } + absl::optional context_id; connect_udp_client_states_.push_back( - ConnectUdpClientState(stream, encapsulated_client_session, this, flow_id, - target_server_address)); + ConnectUdpClientState(stream, encapsulated_client_session, this, + context_id, target_server_address)); return &connect_udp_client_states_.back(); } void MasqueClientSession::SendPacket( QuicConnectionId client_connection_id, - QuicConnectionId server_connection_id, - absl::string_view packet, + QuicConnectionId server_connection_id, absl::string_view packet, const QuicSocketAddress& target_server_address, EncapsulatedClientSession* encapsulated_client_session) { if (masque_mode_ == MasqueMode::kLegacy) { @@ -138,12 +193,15 @@ void MasqueClientSession::SendPacket( return; } - QuicDatagramFlowId flow_id = connect_udp->flow_id(); - MessageStatus message_status = - SendHttp3Datagram(connect_udp->flow_id(), packet); + MessageStatus message_status = SendHttp3Datagram( + connect_udp->stream()->id(), connect_udp->context_id(), packet); QUIC_DVLOG(1) << "Sent packet to " << target_server_address - << " compressed with flow ID " << flow_id + << " compressed with stream ID " << connect_udp->stream()->id() + << " context ID " + << (connect_udp->context_id().has_value() + ? absl::StrCat(connect_udp->context_id().value()) + : "none") << " and got message status " << MessageStatusToString(message_status); } @@ -179,7 +237,11 @@ void MasqueClientSession::UnregisterConnectionId( for (auto it = connect_udp_client_states_.begin(); it != connect_udp_client_states_.end();) { if (it->encapsulated_client_session() == encapsulated_client_session) { - QUIC_DLOG(INFO) << "Removing state for flow_id " << it->flow_id(); + QUIC_DLOG(INFO) << "Removing state for stream ID " << it->stream()->id() + << " context ID " + << (it->context_id().has_value() + ? absl::StrCat(it->context_id().value()) + : "none"); auto* stream = it->stream(); it = connect_udp_client_states_.erase(it); if (!stream->write_side_closed()) { @@ -192,8 +254,7 @@ void MasqueClientSession::UnregisterConnectionId( } void MasqueClientSession::OnConnectionClosed( - const QuicConnectionCloseFrame& frame, - ConnectionCloseSource source) { + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { QuicSpdyClientSession::OnConnectionClosed(frame, source); // Close all encapsulated sessions. for (const auto& client_state : connect_udp_client_states_) { @@ -218,8 +279,10 @@ void MasqueClientSession::OnStreamClosed(QuicStreamId stream_id) { it != connect_udp_client_states_.end();) { if (it->stream()->id() == stream_id) { QUIC_DLOG(INFO) << "Stream " << stream_id - << " was closed, removing state for flow_id " - << it->flow_id(); + << " was closed, removing state for context ID " + << (it->context_id().has_value() + ? absl::StrCat(it->context_id().value()) + : "none"); auto* encapsulated_client_session = it->encapsulated_client_session(); it = connect_udp_client_states_.erase(it); encapsulated_client_session->CloseConnection( @@ -234,24 +297,43 @@ void MasqueClientSession::OnStreamClosed(QuicStreamId stream_id) { QuicSpdyClientSession::OnStreamClosed(stream_id); } +bool MasqueClientSession::OnSettingsFrame(const SettingsFrame& frame) { + QUIC_DLOG(INFO) << "Received SETTINGS: " << frame; + if (!QuicSpdyClientSession::OnSettingsFrame(frame)) { + QUIC_DLOG(ERROR) << "Failed to parse received settings"; + return false; + } + if (!SupportsH3Datagram()) { + QUIC_DLOG(ERROR) << "Refusing to use MASQUE without HTTP/3 Datagrams"; + return false; + } + QUIC_DLOG(INFO) << "Using HTTP Datagram: " << http_datagram_support(); + owner_->OnSettingsReceived(); + return true; +} + MasqueClientSession::ConnectUdpClientState::ConnectUdpClientState( QuicSpdyClientStream* stream, EncapsulatedClientSession* encapsulated_client_session, MasqueClientSession* masque_session, - QuicDatagramFlowId flow_id, + absl::optional context_id, const QuicSocketAddress& target_server_address) : stream_(stream), encapsulated_client_session_(encapsulated_client_session), masque_session_(masque_session), - flow_id_(flow_id), + context_id_(context_id), target_server_address_(target_server_address) { QUICHE_DCHECK_NE(masque_session_, nullptr); - masque_session_->RegisterHttp3FlowId(this->flow_id(), this); + this->stream()->RegisterHttp3DatagramRegistrationVisitor(this); + this->stream()->RegisterHttp3DatagramContextId( + this->context_id(), DatagramFormatType::UDP_PAYLOAD, + /*format_additional_data=*/absl::string_view(), this); } MasqueClientSession::ConnectUdpClientState::~ConnectUdpClientState() { - if (flow_id_.has_value()) { - masque_session_->UnregisterHttp3FlowId(flow_id()); + if (stream() != nullptr) { + stream()->UnregisterHttp3DatagramContextId(context_id()); + stream()->UnregisterHttp3DatagramRegistrationVisitor(); } } @@ -266,23 +348,88 @@ MasqueClientSession::ConnectUdpClientState::operator=( stream_ = other.stream_; encapsulated_client_session_ = other.encapsulated_client_session_; masque_session_ = other.masque_session_; - flow_id_ = other.flow_id_; + context_id_ = other.context_id_; target_server_address_ = other.target_server_address_; - other.flow_id_.reset(); - if (flow_id_.has_value()) { - masque_session_->UnregisterHttp3FlowId(flow_id()); - masque_session_->RegisterHttp3FlowId(flow_id(), this); + other.stream_ = nullptr; + if (stream() != nullptr) { + stream()->MoveHttp3DatagramRegistration(this); + stream()->MoveHttp3DatagramContextIdRegistration(context_id(), this); } return *this; } void MasqueClientSession::ConnectUdpClientState::OnHttp3Datagram( - QuicDatagramFlowId flow_id, + QuicStreamId stream_id, absl::optional context_id, absl::string_view payload) { - QUICHE_DCHECK_EQ(flow_id, this->flow_id()); + QUICHE_DCHECK_EQ(stream_id, stream()->id()); + QUICHE_DCHECK(context_id == context_id_); encapsulated_client_session_->ProcessPacket(payload, target_server_address_); QUIC_DVLOG(1) << "Sent " << payload.size() - << " bytes to connection for flow_id " << flow_id; + << " bytes to connection for stream ID " << stream_id + << " context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) + : "none"); +} + +void MasqueClientSession::ConnectUdpClientState::OnContextReceived( + QuicStreamId stream_id, absl::optional context_id, + DatagramFormatType format_type, absl::string_view format_additional_data) { + if (stream_id != stream_->id()) { + QUIC_BUG(MASQUE client bad datagram context registration) + << "Registered stream ID " << stream_id << ", expected " + << stream_->id(); + return; + } + if (format_type != DatagramFormatType::UDP_PAYLOAD) { + QUIC_DLOG(INFO) << "Ignoring unexpected datagram format type " + << DatagramFormatTypeToString(format_type); + return; + } + if (!format_additional_data.empty()) { + QUIC_DLOG(ERROR) + << "Received non-empty format additional data for context ID " + << (context_id_.has_value() ? context_id_.value() : 0) + << " on stream ID " << stream()->id(); + masque_session_->ResetStream(stream()->id(), QUIC_STREAM_CANCELLED); + return; + } + if (context_id != context_id_) { + QUIC_DLOG(INFO) + << "Ignoring unexpected context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " instead of " + << (context_id_.has_value() ? absl::StrCat(context_id_.value()) + : "none") + << " on stream ID " << stream_->id(); + return; + } + // Do nothing since the client registers first and we currently ignore + // extensions. +} + +void MasqueClientSession::ConnectUdpClientState::OnContextClosed( + QuicStreamId stream_id, absl::optional context_id, + ContextCloseCode close_code, absl::string_view close_details) { + if (stream_id != stream_->id()) { + QUIC_BUG(MASQUE client bad datagram context registration) + << "Closed context on stream ID " << stream_id << ", expected " + << stream_->id(); + return; + } + if (context_id != context_id_) { + QUIC_DLOG(INFO) + << "Ignoring unexpected close of context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " instead of " + << (context_id_.has_value() ? absl::StrCat(context_id_.value()) + : "none") + << " on stream ID " << stream_->id(); + return; + } + QUIC_DLOG(INFO) << "Received datagram context close with close code " + << close_code << " close details \"" << close_details + << "\" on stream ID " << stream_->id() << ", closing stream"; + masque_session_->ResetStream(stream_->id(), QUIC_STREAM_CANCELLED); } } // namespace quic diff --git a/gquiche/quic/masque/masque_client_session.h b/gquiche/quic/masque/masque_client_session.h index 95322389..33da2520 100644 --- a/gquiche/quic/masque/masque_client_session.h +++ b/gquiche/quic/masque/masque_client_session.h @@ -5,6 +5,8 @@ #ifndef QUICHE_QUIC_MASQUE_MASQUE_CLIENT_SESSION_H_ #define QUICHE_QUIC_MASQUE_MASQUE_CLIENT_SESSION_H_ +#include + #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/http/quic_spdy_client_session.h" @@ -31,6 +33,9 @@ class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { // Notifies the owner that the client connection ID is no longer in use. virtual void UnregisterClientConnectionId( QuicConnectionId client_connection_id) = 0; + + // Notifies the owner that a settings frame has been received. + virtual void OnSettingsReceived() = 0; }; // Interface meant to be implemented by encapsulated client sessions, i.e. // the end-to-end QUIC client sessions that run inside MASQUE encapsulation. @@ -53,11 +58,10 @@ class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { // |push_promise_index| or |owner|. All pointers must be non-null. Caller // must ensure that |push_promise_index| and |owner| stay valid for the // lifetime of the newly created MasqueClientSession. - MasqueClientSession(MasqueMode masque_mode, + MasqueClientSession(MasqueMode masque_mode, const std::string& uri_template, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - const QuicServerId& server_id, + QuicConnection* connection, const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, QuicClientPushPromiseIndex* push_promise_index, Owner* owner); @@ -75,6 +79,9 @@ class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { ConnectionCloseSource source) override; void OnStreamClosed(QuicStreamId stream_id) override; + // From QuicSpdySession. + bool OnSettingsFrame(const SettingsFrame& frame) override; + // Send encapsulated packet. void SendPacket(QuicConnectionId client_connection_id, QuicConnectionId server_connection_id, @@ -101,7 +108,8 @@ class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { private: // State that the MasqueClientSession keeps for each CONNECT-UDP request. class QUIC_NO_EXPORT ConnectUdpClientState - : public QuicSpdySession::Http3DatagramVisitor { + : public QuicSpdyStream::Http3DatagramRegistrationVisitor, + public QuicSpdyStream::Http3DatagramVisitor { public: // |stream| and |encapsulated_client_session| must be valid for the lifetime // of the ConnectUdpClientState. @@ -109,7 +117,7 @@ class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { QuicSpdyClientStream* stream, EncapsulatedClientSession* encapsulated_client_session, MasqueClientSession* masque_session, - QuicDatagramFlowId flow_id, + absl::optional context_id, const QuicSocketAddress& target_server_address); ~ConnectUdpClientState(); @@ -124,31 +132,46 @@ class QUIC_NO_EXPORT MasqueClientSession : public QuicSpdyClientSession { EncapsulatedClientSession* encapsulated_client_session() const { return encapsulated_client_session_; } - QuicDatagramFlowId flow_id() const { - QUICHE_DCHECK(flow_id_.has_value()); - return *flow_id_; + absl::optional context_id() const { + return context_id_; } const QuicSocketAddress& target_server_address() const { return target_server_address_; } - // From QuicSpdySession::Http3DatagramVisitor. - void OnHttp3Datagram(QuicDatagramFlowId flow_id, + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::optional context_id, absl::string_view payload) override; + // From QuicSpdyStream::Http3DatagramRegistrationVisitor. + void OnContextReceived(QuicStreamId stream_id, + absl::optional context_id, + DatagramFormatType format_type, + absl::string_view format_additional_data) override; + void OnContextClosed(QuicStreamId stream_id, + absl::optional context_id, + ContextCloseCode close_code, + absl::string_view close_details) override; + private: QuicSpdyClientStream* stream_; // Unowned. EncapsulatedClientSession* encapsulated_client_session_; // Unowned. MasqueClientSession* masque_session_; // Unowned. - absl::optional flow_id_; + absl::optional context_id_; QuicSocketAddress target_server_address_; }; + HttpDatagramSupport LocalHttpDatagramSupport() override { + return HttpDatagramSupport::kDraft00And04; + } + const ConnectUdpClientState* GetOrCreateConnectUdpClientState( const QuicSocketAddress& target_server_address, EncapsulatedClientSession* encapsulated_client_session); MasqueMode masque_mode_; + std::string uri_template_; std::list connect_udp_client_states_; absl::flat_hash_mapperspective() == Perspective::IS_CLIENT ? 0 : 1) {} -QuicDatagramFlowId MasqueCompressionEngine::FindOrCreateCompressionContext( +QuicDatagramStreamId MasqueCompressionEngine::FindOrCreateCompressionContext( QuicConnectionId client_connection_id, QuicConnectionId server_connection_id, - const QuicSocketAddress& server_address, - bool client_connection_id_present, - bool server_connection_id_present, - bool* validated) { - QuicDatagramFlowId flow_id = kFlowId0; + const QuicSocketAddress& server_address, bool client_connection_id_present, + bool server_connection_id_present, bool* validated) { + QuicDatagramStreamId flow_id = kFlowId0; *validated = false; for (const auto& kv : contexts_) { const MasqueCompressionContext& context = kv.second; @@ -74,11 +74,8 @@ QuicDatagramFlowId MasqueCompressionEngine::FindOrCreateCompressionContext( } // Create new compression context. - flow_id = masque_session_->GetNextDatagramFlowId(); - if (flow_id == kFlowId0) { - // Do not use value zero which is reserved in this mode. - flow_id = masque_session_->GetNextDatagramFlowId(); - } + next_available_flow_id_ += 2; + flow_id = next_available_flow_id_; QUIC_DVLOG(1) << "Compression assigning new flow_id " << flow_id << " to " << server_address << " client " << client_connection_id << " server " << server_connection_id; @@ -96,13 +93,9 @@ bool MasqueCompressionEngine::WriteCompressedPacketToSlice( QuicConnectionId server_connection_id, const QuicSocketAddress& server_address, QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - QuicDatagramFlowId flow_id, - bool validated, - uint8_t first_byte, - bool long_header, - QuicDataReader* reader, - QuicDataWriter* writer) { + QuicConnectionId source_connection_id, QuicDatagramStreamId flow_id, + bool validated, uint8_t first_byte, bool long_header, + QuicDataReader* reader, QuicDataWriter* writer) { if (validated) { QUIC_DVLOG(1) << "Compressing using validated flow_id " << flow_id; if (!writer->WriteVarInt62(flow_id)) { @@ -222,8 +215,7 @@ bool MasqueCompressionEngine::WriteCompressedPacketToSlice( } void MasqueCompressionEngine::CompressAndSendPacket( - absl::string_view packet, - QuicConnectionId client_connection_id, + absl::string_view packet, QuicConnectionId client_connection_id, QuicConnectionId server_connection_id, const QuicSocketAddress& server_address) { QUIC_DVLOG(2) << "Compressing client " << client_connection_id << " server " @@ -258,7 +250,7 @@ void MasqueCompressionEngine::CompressAndSendPacket( } bool validated = false; - QuicDatagramFlowId flow_id = FindOrCreateCompressionContext( + QuicDatagramStreamId flow_id = FindOrCreateCompressionContext( client_connection_id, server_connection_id, server_address, client_connection_id_present, server_connection_id_present, &validated); @@ -276,10 +268,10 @@ void MasqueCompressionEngine::CompressAndSendPacket( sizeof(server_address.port()) + sizeof(uint8_t) + server_address.host().ToPackedString().length(); } - QuicUniqueBufferPtr buffer = MakeUniqueBuffer( + QuicBuffer buffer( masque_session_->connection()->helper()->GetStreamSendBufferAllocator(), slice_length); - QuicDataWriter writer(slice_length, buffer.get()); + QuicDataWriter writer(buffer.size(), buffer.data()); if (!WriteCompressedPacketToSlice( client_connection_id, server_connection_id, server_address, @@ -288,18 +280,16 @@ void MasqueCompressionEngine::CompressAndSendPacket( return; } - QuicMemSlice slice(std::move(buffer), slice_length); MessageResult message_result = - masque_session_->SendMessage(QuicMemSliceSpan(&slice)); + masque_session_->SendMessage(QuicMemSlice(std::move(buffer))); QUIC_DVLOG(1) << "Sent packet compressed with flow ID " << flow_id << " and got message result " << message_result; } bool MasqueCompressionEngine::ParseCompressionContext( - QuicDataReader* reader, - MasqueCompressionContext* context) { - QuicDatagramFlowId new_flow_id; + QuicDataReader* reader, MasqueCompressionContext* context) { + QuicDatagramStreamId new_flow_id; if (!reader->ReadVarInt62(&new_flow_id)) { QUIC_DLOG(ERROR) << "Could not read new_flow_id"; return false; @@ -398,10 +388,8 @@ bool MasqueCompressionEngine::ParseCompressionContext( } bool MasqueCompressionEngine::WriteDecompressedPacket( - QuicDataReader* reader, - const MasqueCompressionContext& context, - std::vector* packet, - bool* version_present) { + QuicDataReader* reader, const MasqueCompressionContext& context, + std::vector* packet, bool* version_present) { QuicConnectionId destination_connection_id, source_connection_id; if (masque_session_->perspective() == Perspective::IS_SERVER) { destination_connection_id = context.server_connection_id; @@ -464,16 +452,13 @@ bool MasqueCompressionEngine::WriteDecompressedPacket( } bool MasqueCompressionEngine::DecompressDatagram( - absl::string_view datagram, - QuicConnectionId* client_connection_id, - QuicConnectionId* server_connection_id, - QuicSocketAddress* server_address, - std::vector* packet, - bool* version_present) { + absl::string_view datagram, QuicConnectionId* client_connection_id, + QuicConnectionId* server_connection_id, QuicSocketAddress* server_address, + std::vector* packet, bool* version_present) { QUIC_DVLOG(1) << "Decompressing DATAGRAM frame of length " << datagram.length(); QuicDataReader reader(datagram); - QuicDatagramFlowId flow_id; + QuicDatagramStreamId flow_id; if (!reader.ReadVarInt62(&flow_id)) { QUIC_DLOG(ERROR) << "Could not read flow_id"; return false; @@ -525,14 +510,14 @@ bool MasqueCompressionEngine::DecompressDatagram( void MasqueCompressionEngine::UnregisterClientConnectionId( QuicConnectionId client_connection_id) { - std::vector flow_ids_to_remove; + std::vector flow_ids_to_remove; for (const auto& kv : contexts_) { const MasqueCompressionContext& context = kv.second; if (context.client_connection_id == client_connection_id) { flow_ids_to_remove.push_back(kv.first); } } - for (QuicDatagramFlowId flow_id : flow_ids_to_remove) { + for (QuicDatagramStreamId flow_id : flow_ids_to_remove) { contexts_.erase(flow_id); } } diff --git a/gquiche/quic/masque/masque_compression_engine.h b/gquiche/quic/masque/masque_compression_engine.h index 256c516d..a78052f0 100644 --- a/gquiche/quic/masque/masque_compression_engine.h +++ b/gquiche/quic/masque/masque_compression_engine.h @@ -63,8 +63,7 @@ class QUIC_NO_EXPORT MasqueCompressionEngine { QuicConnectionId* client_connection_id, QuicConnectionId* server_connection_id, QuicSocketAddress* server_address, - std::vector* packet, - bool* version_present); + std::vector* packet, bool* version_present); // Clears all entries referencing |client_connection_id| from the // compression table. @@ -83,12 +82,11 @@ class QUIC_NO_EXPORT MasqueCompressionEngine { // whether the corresponding connection ID is present in the current packet. // |validated| will contain whether the compression context that matches // these arguments is currently validated or not. - QuicDatagramFlowId FindOrCreateCompressionContext( + QuicDatagramStreamId FindOrCreateCompressionContext( QuicConnectionId client_connection_id, QuicConnectionId server_connection_id, const QuicSocketAddress& server_address, - bool client_connection_id_present, - bool server_connection_id_present, + bool client_connection_id_present, bool server_connection_id_present, bool* validated); // Writes compressed packet to |slice| during compression. @@ -97,11 +95,9 @@ class QUIC_NO_EXPORT MasqueCompressionEngine { const QuicSocketAddress& server_address, QuicConnectionId destination_connection_id, QuicConnectionId source_connection_id, - QuicDatagramFlowId flow_id, - bool validated, - uint8_t first_byte, - bool long_header, - QuicDataReader* reader, + QuicDatagramStreamId flow_id, + bool validated, uint8_t first_byte, + bool long_header, QuicDataReader* reader, QuicDataWriter* writer); // Parses compression context from flow ID 0 during decompression. @@ -115,7 +111,8 @@ class QUIC_NO_EXPORT MasqueCompressionEngine { bool* version_present); QuicSpdySession* masque_session_; // Unowned. - absl::flat_hash_map contexts_; + absl::flat_hash_map contexts_; + QuicDatagramStreamId next_available_flow_id_; }; } // namespace quic diff --git a/gquiche/quic/masque/masque_dispatcher.cc b/gquiche/quic/masque/masque_dispatcher.cc index fb6fc195..4ab76a4f 100644 --- a/gquiche/quic/masque/masque_dispatcher.cc +++ b/gquiche/quic/masque/masque_dispatcher.cc @@ -31,12 +31,10 @@ MasqueDispatcher::MasqueDispatcher( masque_server_backend_(masque_server_backend) {} std::unique_ptr MasqueDispatcher::CreateQuicSession( - QuicConnectionId connection_id, - const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view /*alpn*/, + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, const ParsedQuicVersion& version, - absl::string_view /*sni*/) { + const ParsedClientHello& /*parsed_chlo*/) { // The MasqueServerSession takes ownership of |connection| below. QuicConnection* connection = new QuicConnection(connection_id, self_address, peer_address, helper(), diff --git a/gquiche/quic/masque/masque_dispatcher.h b/gquiche/quic/masque/masque_dispatcher.h index 385899c7..1c4948e2 100644 --- a/gquiche/quic/masque/masque_dispatcher.h +++ b/gquiche/quic/masque/masque_dispatcher.h @@ -38,12 +38,10 @@ class QUIC_NO_EXPORT MasqueDispatcher : public QuicSimpleDispatcher, // From QuicSimpleDispatcher. std::unique_ptr CreateQuicSession( - QuicConnectionId connection_id, - const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, const ParsedQuicVersion& version, - absl::string_view sni) override; + const quic::ParsedClientHello& parsed_chlo) override; bool OnFailedToDispatchPacket(const ReceivedPacketInfo& packet_info) override; diff --git a/gquiche/quic/masque/masque_epoll_client.cc b/gquiche/quic/masque/masque_epoll_client.cc index 2fbe1a25..a8f0b512 100644 --- a/gquiche/quic/masque/masque_epoll_client.cc +++ b/gquiche/quic/masque/masque_epoll_client.cc @@ -3,26 +3,25 @@ // found in the LICENSE file. #include "gquiche/quic/masque/masque_epoll_client.h" + +#include + #include "absl/memory/memory.h" #include "gquiche/quic/masque/masque_client_session.h" #include "gquiche/quic/masque/masque_utils.h" +#include "gquiche/quic/tools/quic_url.h" namespace quic { MasqueEpollClient::MasqueEpollClient( - QuicSocketAddress server_address, - const QuicServerId& server_id, - MasqueMode masque_mode, - QuicEpollServer* epoll_server, + QuicSocketAddress server_address, const QuicServerId& server_id, + MasqueMode masque_mode, QuicEpollServer* epoll_server, std::unique_ptr proof_verifier, - const std::string& authority) - : QuicClient(server_address, - server_id, - MasqueSupportedVersions(), - epoll_server, - std::move(proof_verifier)), + const std::string& uri_template) + : QuicClient(server_address, server_id, MasqueSupportedVersions(), + epoll_server, std::move(proof_verifier)), masque_mode_(masque_mode), - authority_(authority) {} + uri_template_(uri_template) {} std::unique_ptr MasqueEpollClient::CreateQuicClientSession( const ParsedQuicVersionVector& supported_versions, @@ -30,8 +29,8 @@ std::unique_ptr MasqueEpollClient::CreateQuicClientSession( QUIC_DLOG(INFO) << "Creating MASQUE session for " << connection->connection_id(); return std::make_unique( - masque_mode_, *config(), supported_versions, connection, server_id(), - crypto_config(), push_promise_index(), this); + masque_mode_, uri_template_, *config(), supported_versions, connection, + server_id(), crypto_config(), push_promise_index(), this); } MasqueClientSession* MasqueEpollClient::masque_client_session() { @@ -42,13 +41,19 @@ QuicConnectionId MasqueEpollClient::connection_id() { return masque_client_session()->connection_id(); } +std::string MasqueEpollClient::authority() const { + QuicUrl url(uri_template_); + return absl::StrCat(url.host(), ":", url.port()); +} + // static std::unique_ptr MasqueEpollClient::Create( - const std::string& host, - int port, - MasqueMode masque_mode, + const std::string& uri_template, MasqueMode masque_mode, QuicEpollServer* epoll_server, std::unique_ptr proof_verifier) { + QuicUrl url(uri_template); + std::string host = url.host(); + uint16_t port = url.port(); // Build the masque_client, and try to connect. QuicSocketAddress addr = tools::LookupAddress(host, absl::StrCat(port)); if (!addr.IsInitialized()) { @@ -59,16 +64,16 @@ std::unique_ptr MasqueEpollClient::Create( // Use absl::WrapUnique(new MasqueEpollClient(...)) instead of // std::make_unique(...) because the constructor for // MasqueEpollClient is private and therefore not accessible from make_unique. - auto masque_client = absl::WrapUnique(new MasqueEpollClient( - addr, server_id, masque_mode, epoll_server, std::move(proof_verifier), - absl::StrCat(host, ":", port))); + auto masque_client = absl::WrapUnique( + new MasqueEpollClient(addr, server_id, masque_mode, epoll_server, + std::move(proof_verifier), uri_template)); if (masque_client == nullptr) { QUIC_LOG(ERROR) << "Failed to create masque_client"; return nullptr; } - masque_client->set_initial_max_packet_length(kDefaultMaxPacketSize); + masque_client->set_initial_max_packet_length(kMasqueMaxOuterPacketSize); masque_client->set_drop_response_body(false); if (!masque_client->Initialize()) { QUIC_LOG(ERROR) << "Failed to initialize masque_client"; @@ -81,12 +86,17 @@ std::unique_ptr MasqueEpollClient::Create( return nullptr; } + if (!masque_client->WaitUntilSettingsReceived()) { + QUIC_LOG(ERROR) << "Failed to receive settings"; + return nullptr; + } + if (masque_client->masque_mode() == MasqueMode::kLegacy) { // Construct the legacy mode init request. spdy::Http2HeaderBlock header_block; header_block[":method"] = "POST"; header_block[":scheme"] = "https"; - header_block[":authority"] = masque_client->authority_; + header_block[":authority"] = masque_client->authority(); header_block[":path"] = "/.well-known/masque/init"; std::string body = "foo"; @@ -114,6 +124,17 @@ std::unique_ptr MasqueEpollClient::Create( return masque_client; } +void MasqueEpollClient::OnSettingsReceived() { + settings_received_ = true; +} + +bool MasqueEpollClient::WaitUntilSettingsReceived() { + while (connected() && !settings_received_) { + network_helper()->RunEventLoop(); + } + return connected() && settings_received_; +} + void MasqueEpollClient::UnregisterClientConnectionId( QuicConnectionId client_connection_id) { std::string body(client_connection_id.data(), client_connection_id.length()); @@ -122,7 +143,7 @@ void MasqueEpollClient::UnregisterClientConnectionId( spdy::Http2HeaderBlock header_block; header_block[":method"] = "POST"; header_block[":scheme"] = "https"; - header_block[":authority"] = authority_; + header_block[":authority"] = authority(); header_block[":path"] = "/.well-known/masque/unregister"; // Make sure to store the response, for later output. diff --git a/gquiche/quic/masque/masque_epoll_client.h b/gquiche/quic/masque/masque_epoll_client.h index 9b93bb80..c06dca20 100644 --- a/gquiche/quic/masque/masque_epoll_client.h +++ b/gquiche/quic/masque/masque_epoll_client.h @@ -5,10 +5,13 @@ #ifndef QUICHE_QUIC_MASQUE_MASQUE_EPOLL_CLIENT_H_ #define QUICHE_QUIC_MASQUE_MASQUE_EPOLL_CLIENT_H_ +#include + #include "gquiche/quic/masque/masque_client_session.h" #include "gquiche/quic/masque/masque_utils.h" #include "gquiche/quic/platform/api/quic_export.h" #include "gquiche/quic/tools/quic_client.h" +#include "gquiche/quic/tools/quic_url.h" namespace quic { @@ -18,9 +21,7 @@ class QUIC_NO_EXPORT MasqueEpollClient : public QuicClient, public: // Constructs a MasqueEpollClient, performs a synchronous DNS lookup. static std::unique_ptr Create( - const std::string& host, - int port, - MasqueMode masque_mode, + const std::string& uri_template, MasqueMode masque_mode, QuicEpollServer* epoll_server, std::unique_ptr proof_verifier); @@ -35,6 +36,8 @@ class QUIC_NO_EXPORT MasqueEpollClient : public QuicClient, // Convenience accessor for the underlying connection ID. QuicConnectionId connection_id(); + // From MasqueClientSession::Owner. + void OnSettingsReceived() override; // Send a MASQUE client connection ID unregister command to the server. void UnregisterClientConnectionId( QuicConnectionId client_connection_id) override; @@ -44,18 +47,24 @@ class QUIC_NO_EXPORT MasqueEpollClient : public QuicClient, private: // Constructor is private, use Create() instead. MasqueEpollClient(QuicSocketAddress server_address, - const QuicServerId& server_id, - MasqueMode masque_mode, + const QuicServerId& server_id, MasqueMode masque_mode, QuicEpollServer* epoll_server, std::unique_ptr proof_verifier, - const std::string& authority); + const std::string& uri_template); + + // Wait synchronously until we receive the peer's settings. Returns whether + // they were received. + bool WaitUntilSettingsReceived(); + + std::string authority() const; // Disallow copy and assign. MasqueEpollClient(const MasqueEpollClient&) = delete; MasqueEpollClient& operator=(const MasqueEpollClient&) = delete; MasqueMode masque_mode_; - std::string authority_; + std::string uri_template_; + bool settings_received_ = false; }; } // namespace quic diff --git a/gquiche/quic/masque/masque_server_backend.cc b/gquiche/quic/masque/masque_server_backend.cc index d9bcea0e..4dcf9b6d 100644 --- a/gquiche/quic/masque/masque_server_backend.cc +++ b/gquiche/quic/masque/masque_server_backend.cc @@ -50,7 +50,9 @@ bool MasqueServerBackend::MaybeHandleMasqueRequest( masque_path = std::string(path.substr(sizeof("/.well-known/masque/") - 1)); } else { QUICHE_DCHECK_EQ(masque_mode_, MasqueMode::kOpen); - if (method != "CONNECT-UDP") { + auto protocol_pair = request_headers.find(":protocol"); + if (method != "CONNECT" || protocol_pair == request_headers.end() || + protocol_pair->second != "connect-udp") { // This is not a MASQUE request. return false; } @@ -90,7 +92,7 @@ bool MasqueServerBackend::MaybeHandleMasqueRequest( QUIC_DLOG(INFO) << "Sending MASQUE response for " << request_headers.DebugString(); - request_handler->OnResponseBackendComplete(response.get(), {}); + request_handler->OnResponseBackendComplete(response.get()); it->second.responses.emplace_back(std::move(response)); return true; diff --git a/gquiche/quic/masque/masque_server_bin.cc b/gquiche/quic/masque/masque_server_bin.cc index 471f63ec..6d966282 100644 --- a/gquiche/quic/masque/masque_server_bin.cc +++ b/gquiche/quic/masque/masque_server_bin.cc @@ -52,8 +52,6 @@ int main(int argc, char* argv[]) { return 0; } - SetQuicReloadableFlag(quic_h3_datagram, true); - quic::MasqueMode masque_mode = quic::MasqueMode::kOpen; std::string mode_string = GetQuicFlag(FLAGS_masque_mode); if (mode_string == "legacy") { diff --git a/gquiche/quic/masque/masque_server_session.cc b/gquiche/quic/masque/masque_server_session.cc index c2917649..339f8ddf 100644 --- a/gquiche/quic/masque/masque_server_session.cc +++ b/gquiche/quic/masque/masque_server_session.cc @@ -6,14 +6,18 @@ #include +#include +#include + #include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/core/quic_data_reader.h" #include "gquiche/quic/core/quic_udp_socket.h" #include "gquiche/quic/tools/quic_url.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/platform/api/quiche_url_utils.h" namespace quic { @@ -59,8 +63,7 @@ class FdWrapper { }; std::unique_ptr CreateBackendErrorResponse( - absl::string_view status, - absl::string_view error_details) { + absl::string_view status, absl::string_view error_details) { spdy::Http2HeaderBlock response_headers; response_headers[":status"] = status; response_headers["masque-debug-info"] = error_details; @@ -73,24 +76,15 @@ std::unique_ptr CreateBackendErrorResponse( } // namespace MasqueServerSession::MasqueServerSession( - MasqueMode masque_mode, - const QuicConfig& config, + MasqueMode masque_mode, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - QuicSession::Visitor* visitor, - Visitor* owner, - QuicEpollServer* epoll_server, - QuicCryptoServerStreamBase::Helper* helper, + QuicConnection* connection, QuicSession::Visitor* visitor, Visitor* owner, + QuicEpollServer* epoll_server, QuicCryptoServerStreamBase::Helper* helper, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache, MasqueServerBackend* masque_server_backend) - : QuicSimpleServerSession(config, - supported_versions, - connection, - visitor, - helper, - crypto_config, - compressed_certs_cache, + : QuicSimpleServerSession(config, supported_versions, connection, visitor, + helper, crypto_config, compressed_certs_cache, masque_server_backend), masque_server_backend_(masque_server_backend), owner_(owner), @@ -154,8 +148,7 @@ void MasqueServerSession::OnMessageLost(QuicMessageId message_id) { } void MasqueServerSession::OnConnectionClosed( - const QuicConnectionCloseFrame& frame, - ConnectionCloseSource source) { + const QuicConnectionCloseFrame& frame, ConnectionCloseSource source) { QuicSimpleServerSession::OnConnectionClosed(frame, source); QUIC_DLOG(INFO) << "Closing connection for " << connection_id(); masque_server_backend_->RemoveBackendClient(connection_id()); @@ -166,7 +159,7 @@ void MasqueServerSession::OnConnectionClosed( void MasqueServerSession::OnStreamClosed(QuicStreamId stream_id) { connect_udp_server_states_.remove_if( [stream_id](const ConnectUdpServerState& connect_udp) { - return connect_udp.stream_id() == stream_id; + return connect_udp.stream()->id() == stream_id; }); QuicSimpleServerSession::OnStreamClosed(stream_id); @@ -181,6 +174,7 @@ std::unique_ptr MasqueServerSession::HandleMasqueRequest( auto path_pair = request_headers.find(":path"); auto scheme_pair = request_headers.find(":scheme"); auto method_pair = request_headers.find(":method"); + auto protocol_pair = request_headers.find(":protocol"); auto authority_pair = request_headers.find(":authority"); if (path_pair == request_headers.end()) { QUIC_DLOG(ERROR) << "MASQUE request is missing :path"; @@ -194,6 +188,10 @@ std::unique_ptr MasqueServerSession::HandleMasqueRequest( QUIC_DLOG(ERROR) << "MASQUE request is missing :method"; return CreateBackendErrorResponse("400", "Missing :method"); } + if (protocol_pair == request_headers.end()) { + QUIC_DLOG(ERROR) << "MASQUE request is missing :protocol"; + return CreateBackendErrorResponse("400", "Missing :protocol"); + } if (authority_pair == request_headers.end()) { QUIC_DLOG(ERROR) << "MASQUE request is missing :authority"; return CreateBackendErrorResponse("400", "Missing :authority"); @@ -201,6 +199,7 @@ std::unique_ptr MasqueServerSession::HandleMasqueRequest( absl::string_view path = path_pair->second; absl::string_view scheme = scheme_pair->second; absl::string_view method = method_pair->second; + absl::string_view protocol = protocol_pair->second; absl::string_view authority = authority_pair->second; if (path.empty()) { QUIC_DLOG(ERROR) << "MASQUE request with empty path"; @@ -208,34 +207,52 @@ std::unique_ptr MasqueServerSession::HandleMasqueRequest( } if (scheme.empty()) { return CreateBackendErrorResponse("400", "Empty scheme"); - return nullptr; } - if (method != "CONNECT-UDP") { + if (method != "CONNECT") { QUIC_DLOG(ERROR) << "MASQUE request with bad method \"" << method << "\""; return CreateBackendErrorResponse("400", "Bad method"); } - absl::optional flow_id = - SpdyUtils::ParseDatagramFlowIdHeader(request_headers); - if (!flow_id.has_value()) { - QUIC_DLOG(ERROR) - << "MASQUE request with bad or missing DatagramFlowId header"; - return CreateBackendErrorResponse("400", - "Bad or missing DatagramFlowId header"); - } - QuicUrl url(absl::StrCat("https://", authority)); - if (!url.IsValid() || url.PathParamsQuery() != "/") { - QUIC_DLOG(ERROR) << "MASQUE request with bad authority \"" << authority + if (protocol != "connect-udp") { + QUIC_DLOG(ERROR) << "MASQUE request with bad protocol \"" << protocol << "\""; - return CreateBackendErrorResponse("400", "Bad authority"); + return CreateBackendErrorResponse("400", "Bad protocol"); + } + absl::optional flow_id; + if (http_datagram_support() == HttpDatagramSupport::kDraft00) { + flow_id = SpdyUtils::ParseDatagramFlowIdHeader(request_headers); + if (!flow_id.has_value()) { + QUIC_DLOG(ERROR) + << "MASQUE request with bad or missing DatagramFlowId header"; + return CreateBackendErrorResponse( + "400", "Bad or missing DatagramFlowId header"); + } + } + // Extract target host and port from path using default template. + std::vector path_split = absl::StrSplit(path, '/'); + if (path_split.size() != 4 || !path_split[0].empty() || + path_split[1].empty() || path_split[2].empty() || + !path_split[3].empty()) { + QUIC_DLOG(ERROR) << "MASQUE request with bad path \"" << path << "\""; + return CreateBackendErrorResponse("400", "Bad path"); + } + absl::optional host = quiche::AsciiUrlDecode(path_split[1]); + if (!host.has_value()) { + QUIC_DLOG(ERROR) << "Failed to decode host \"" << path_split[1] << "\""; + return CreateBackendErrorResponse("500", "Failed to decode host"); + } + absl::optional port = quiche::AsciiUrlDecode(path_split[2]); + if (!port.has_value()) { + QUIC_DLOG(ERROR) << "Failed to decode port \"" << path_split[2] << "\""; + return CreateBackendErrorResponse("500", "Failed to decode port"); } - std::string port = absl::StrCat(url.port()); + // Perform DNS resolution. addrinfo hint = {}; hint.ai_protocol = IPPROTO_UDP; addrinfo* info_list = nullptr; - int result = - getaddrinfo(url.host().c_str(), port.c_str(), &hint, &info_list); + int result = getaddrinfo(host.value().c_str(), port.value().c_str(), &hint, + &info_list); if (result != 0) { QUIC_DLOG(ERROR) << "Failed to resolve " << authority << ": " << gai_strerror(result); @@ -247,7 +264,9 @@ std::unique_ptr MasqueServerSession::HandleMasqueRequest( info_list, freeaddrinfo); QuicSocketAddress target_server_address(info_list->ai_addr, info_list->ai_addrlen); - QUIC_DLOG(INFO) << "Got CONNECT_UDP request flow_id=" << *flow_id + QUIC_DLOG(INFO) << "Got CONNECT_UDP request on stream ID " + << request_handler->stream_id() << " flow_id=" + << (flow_id.has_value() ? absl::StrCat(*flow_id) : "none") << " target_server_address=\"" << target_server_address << "\""; @@ -256,21 +275,47 @@ std::unique_ptr MasqueServerSession::HandleMasqueRequest( QUIC_DLOG(ERROR) << "Socket creation failed"; return CreateBackendErrorResponse("500", "Socket creation failed"); } - QuicSocketAddress any_v6_address(QuicIpAddress::Any6(), 0); + QuicSocketAddress empty_address(QuicIpAddress::Any6(), 0); + if (target_server_address.host().IsIPv4()) { + empty_address = QuicSocketAddress(QuicIpAddress::Any4(), 0); + } QuicUdpSocketApi socket_api; - if (!socket_api.Bind(fd_wrapper.fd(), any_v6_address)) { + if (!socket_api.Bind(fd_wrapper.fd(), empty_address)) { QUIC_DLOG(ERROR) << "Socket bind failed"; return CreateBackendErrorResponse("500", "Socket bind failed"); } epoll_server_->RegisterFDForRead(fd_wrapper.fd(), this); - connect_udp_server_states_.emplace_back(ConnectUdpServerState( - *flow_id, request_handler->stream_id(), target_server_address, - fd_wrapper.extract_fd(), this)); + absl::optional context_id; + QuicSpdyStream* stream = static_cast( + GetActiveStream(request_handler->stream_id())); + if (stream == nullptr) { + QUIC_BUG(bad masque server stream type) + << "Unexpected stream type for stream ID " + << request_handler->stream_id(); + return CreateBackendErrorResponse("500", "Bad stream type"); + } + if (flow_id.has_value()) { + stream->RegisterHttp3DatagramFlowId(*flow_id); + } + connect_udp_server_states_.push_back( + ConnectUdpServerState(stream, context_id, target_server_address, + fd_wrapper.extract_fd(), this)); + + if (http_datagram_support() == HttpDatagramSupport::kDraft00) { + // TODO(b/181256914) remove this when we drop support for + // draft-ietf-masque-h3-datagram-00 in favor of later drafts. + stream->RegisterHttp3DatagramContextId( + context_id, DatagramFormatType::UDP_PAYLOAD, + /*format_additional_data=*/absl::string_view(), + &connect_udp_server_states_.back()); + } spdy::Http2HeaderBlock response_headers; response_headers[":status"] = "200"; - SpdyUtils::AddDatagramFlowIdHeader(&response_headers, *flow_id); + if (flow_id.has_value()) { + SpdyUtils::AddDatagramFlowIdHeader(&response_headers, *flow_id); + } auto response = std::make_unique(); response->set_response_type(QuicBackendResponse::INCOMPLETE_RESPONSE); response->set_headers(std::move(response_headers)); @@ -327,8 +372,7 @@ void MasqueServerSession::HandlePacketFromServer( } void MasqueServerSession::OnRegistration(QuicEpollServer* /*eps*/, - QuicUdpSocketFd fd, - int event_mask) { + QuicUdpSocketFd fd, int event_mask) { QUIC_DVLOG(1) << "OnRegistration " << fd << " event_mask " << event_mask; } @@ -351,13 +395,12 @@ void MasqueServerSession::OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) { << event->in_events << " on unknown fd " << fd; return; } - QuicDatagramFlowId flow_id = it->flow_id(); QuicSocketAddress expected_target_server_address = it->target_server_address(); QUICHE_DCHECK(expected_target_server_address.IsInitialized()); QUIC_DVLOG(1) << "Received readable event on fd " << fd << " (mask " - << event->in_events << ") flow_id " << flow_id << " server " - << expected_target_server_address; + << event->in_events << ") stream ID " << it->stream()->id() + << " server " << expected_target_server_address; QuicUdpSocketApi socket_api; BitMask64 packet_info_interested(QuicUdpPacketInfoBit::PEER_ADDRESS); char packet_buffer[kMaxIncomingPacketSize]; @@ -393,12 +436,14 @@ void MasqueServerSession::OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) { return; } // The packet is valid, send it to the client in a DATAGRAM frame. - MessageStatus message_status = SendHttp3Datagram( - flow_id, absl::string_view(read_result.packet_buffer.buffer, - read_result.packet_buffer.buffer_len)); + MessageStatus message_status = it->stream()->SendHttp3Datagram( + it->context_id(), + absl::string_view(read_result.packet_buffer.buffer, + read_result.packet_buffer.buffer_len)); QUIC_DVLOG(1) << "Sent UDP packet from " << expected_target_server_address << " of length " << read_result.packet_buffer.buffer_len - << " with flow ID " << flow_id << " and got message status " + << " with stream ID " << it->stream()->id() + << " and got message status " << MessageStatusToString(message_status); } } @@ -417,25 +462,39 @@ std::string MasqueServerSession::Name() const { return std::string("MasqueServerSession-") + connection_id().ToString(); } +bool MasqueServerSession::OnSettingsFrame(const SettingsFrame& frame) { + QUIC_DLOG(INFO) << "Received SETTINGS: " << frame; + if (!QuicSimpleServerSession::OnSettingsFrame(frame)) { + return false; + } + if (!SupportsH3Datagram()) { + QUIC_DLOG(ERROR) << "Refusing to use MASQUE without HTTP Datagrams"; + return false; + } + QUIC_DLOG(INFO) << "Using HTTP Datagram: " << http_datagram_support(); + return true; +} + MasqueServerSession::ConnectUdpServerState::ConnectUdpServerState( - QuicDatagramFlowId flow_id, - QuicStreamId stream_id, - const QuicSocketAddress& target_server_address, - QuicUdpSocketFd fd, + QuicSpdyStream* stream, absl::optional context_id, + const QuicSocketAddress& target_server_address, QuicUdpSocketFd fd, MasqueServerSession* masque_session) - : flow_id_(flow_id), - stream_id_(stream_id), + : stream_(stream), + context_id_(context_id), target_server_address_(target_server_address), fd_(fd), masque_session_(masque_session) { QUICHE_DCHECK_NE(fd_, kQuicInvalidSocketFd); QUICHE_DCHECK_NE(masque_session_, nullptr); - masque_session_->RegisterHttp3FlowId(this->flow_id(), this); + this->stream()->RegisterHttp3DatagramRegistrationVisitor(this); } MasqueServerSession::ConnectUdpServerState::~ConnectUdpServerState() { - if (flow_id_.has_value()) { - masque_session_->UnregisterHttp3FlowId(flow_id()); + if (stream() != nullptr) { + if (context_registered_) { + stream()->UnregisterHttp3DatagramContextId(context_id()); + } + stream()->UnregisterHttp3DatagramRegistrationVisitor(); } if (fd_ == kQuicInvalidSocketFd) { return; @@ -461,24 +520,29 @@ MasqueServerSession::ConnectUdpServerState::operator=( masque_session_->epoll_server()->UnregisterFD(fd_); socket_api.Destroy(fd_); } - flow_id_ = other.flow_id_; - stream_id_ = other.stream_id_; + stream_ = other.stream_; + other.stream_ = nullptr; + context_id_ = other.context_id_; target_server_address_ = other.target_server_address_; fd_ = other.fd_; masque_session_ = other.masque_session_; other.fd_ = kQuicInvalidSocketFd; - other.flow_id_.reset(); - if (flow_id_.has_value()) { - masque_session_->UnregisterHttp3FlowId(flow_id()); - masque_session_->RegisterHttp3FlowId(flow_id(), this); + context_registered_ = other.context_registered_; + other.context_registered_ = false; + if (stream() != nullptr) { + stream()->MoveHttp3DatagramRegistration(this); + if (context_registered_) { + stream()->MoveHttp3DatagramContextIdRegistration(context_id(), this); + } } return *this; } void MasqueServerSession::ConnectUdpServerState::OnHttp3Datagram( - QuicDatagramFlowId flow_id, + QuicStreamId stream_id, absl::optional context_id, absl::string_view payload) { - QUICHE_DCHECK_EQ(flow_id, this->flow_id()); + QUICHE_DCHECK_EQ(stream_id, stream()->id()); + QUICHE_DCHECK(context_id == context_id_); QuicUdpSocketApi socket_api; QuicUdpPacketInfo packet_info; packet_info.SetPeerAddress(target_server_address_); @@ -488,4 +552,77 @@ void MasqueServerSession::ConnectUdpServerState::OnHttp3Datagram( << target_server_address_ << " with result " << write_result; } +void MasqueServerSession::ConnectUdpServerState::OnContextReceived( + QuicStreamId stream_id, absl::optional context_id, + DatagramFormatType format_type, absl::string_view format_additional_data) { + if (stream_id != stream()->id()) { + QUIC_BUG(MASQUE server bad datagram context registration) + << "Registered stream ID " << stream_id << ", expected " + << stream()->id(); + return; + } + if (format_type != DatagramFormatType::UDP_PAYLOAD) { + QUIC_DLOG(INFO) << "Ignoring unexpected datagram format type " + << DatagramFormatTypeToString(format_type); + return; + } + if (!format_additional_data.empty()) { + QUIC_DLOG(ERROR) + << "Received non-empty format additional data for context ID " + << (context_id_.has_value() ? context_id_.value() : 0) + << " on stream ID " << stream()->id(); + masque_session_->ResetStream(stream()->id(), QUIC_STREAM_CANCELLED); + return; + } + if (!context_received_) { + context_received_ = true; + context_id_ = context_id; + } + if (context_id != context_id_) { + QUIC_DLOG(INFO) + << "Ignoring unexpected context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " instead of " + << (context_id_.has_value() ? absl::StrCat(context_id_.value()) + : "none") + << " on stream ID " << stream()->id(); + return; + } + if (context_registered_) { + QUIC_BUG(MASQUE server double datagram context registration) + << "Try to re-register stream ID " << stream_id << " context ID " + << (context_id_.has_value() ? absl::StrCat(context_id_.value()) + : "none"); + return; + } + context_registered_ = true; + stream()->RegisterHttp3DatagramContextId(context_id_, format_type, + format_additional_data, this); +} + +void MasqueServerSession::ConnectUdpServerState::OnContextClosed( + QuicStreamId stream_id, absl::optional context_id, + ContextCloseCode close_code, absl::string_view close_details) { + if (stream_id != stream()->id()) { + QUIC_BUG(MASQUE server bad datagram context registration) + << "Closed context on stream ID " << stream_id << ", expected " + << stream()->id(); + return; + } + if (context_id != context_id_) { + QUIC_DLOG(INFO) + << "Ignoring unexpected close of context ID " + << (context_id.has_value() ? absl::StrCat(context_id.value()) : "none") + << " instead of " + << (context_id_.has_value() ? absl::StrCat(context_id_.value()) + : "none") + << " on stream ID " << stream()->id(); + return; + } + QUIC_DLOG(INFO) << "Received datagram context close with close code " + << close_code << " close details \"" << close_details + << "\" on stream ID " << stream()->id() << ", closing stream"; + masque_session_->ResetStream(stream()->id(), QUIC_STREAM_CANCELLED); +} + } // namespace quic diff --git a/gquiche/quic/masque/masque_server_session.h b/gquiche/quic/masque/masque_server_session.h index 82a4020a..e4628736 100644 --- a/gquiche/quic/masque/masque_server_session.h +++ b/gquiche/quic/masque/masque_server_session.h @@ -38,14 +38,10 @@ class QUIC_NO_EXPORT MasqueServerSession }; explicit MasqueServerSession( - MasqueMode masque_mode, - const QuicConfig& config, + MasqueMode masque_mode, const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - QuicSession::Visitor* visitor, - Visitor* owner, - QuicEpollServer* epoll_server, - QuicCryptoServerStreamBase::Helper* helper, + QuicConnection* connection, QuicSession::Visitor* visitor, Visitor* owner, + QuicEpollServer* epoll_server, QuicCryptoServerStreamBase::Helper* helper, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache, MasqueServerBackend* masque_server_backend); @@ -71,8 +67,7 @@ class QUIC_NO_EXPORT MasqueServerSession QuicSimpleServerBackend::RequestHandler* request_handler) override; // From QuicEpollCallbackInterface. - void OnRegistration(QuicEpollServer* eps, - QuicUdpSocketFd fd, + void OnRegistration(QuicEpollServer* eps, QuicUdpSocketFd fd, int event_mask) override; void OnModification(QuicUdpSocketFd fd, int event_mask) override; void OnEvent(QuicUdpSocketFd fd, QuicEpollEvent* event) override; @@ -88,15 +83,15 @@ class QUIC_NO_EXPORT MasqueServerSession private: // State that the MasqueServerSession keeps for each CONNECT-UDP request. class QUIC_NO_EXPORT ConnectUdpServerState - : public QuicSpdySession::Http3DatagramVisitor { + : public QuicSpdyStream::Http3DatagramRegistrationVisitor, + public QuicSpdyStream::Http3DatagramVisitor { public: // ConnectUdpServerState takes ownership of |fd|. It will unregister it // from |epoll_server| and close the file descriptor when destructed. explicit ConnectUdpServerState( - QuicDatagramFlowId flow_id, - QuicStreamId stream_id, - const QuicSocketAddress& target_server_address, - QuicUdpSocketFd fd, + QuicSpdyStream* stream, + absl::optional context_id, + const QuicSocketAddress& target_server_address, QuicUdpSocketFd fd, MasqueServerSession* masque_session); ~ConnectUdpServerState(); @@ -107,28 +102,46 @@ class QUIC_NO_EXPORT MasqueServerSession ConnectUdpServerState& operator=(const ConnectUdpServerState&) = delete; ConnectUdpServerState& operator=(ConnectUdpServerState&&); - QuicDatagramFlowId flow_id() const { - QUICHE_DCHECK(flow_id_.has_value()); - return *flow_id_; + QuicSpdyStream* stream() const { return stream_; } + absl::optional context_id() const { + return context_id_; } - QuicStreamId stream_id() const { return stream_id_; } const QuicSocketAddress& target_server_address() const { return target_server_address_; } QuicUdpSocketFd fd() const { return fd_; } - // From QuicSpdySession::Http3DatagramVisitor. - void OnHttp3Datagram(QuicDatagramFlowId flow_id, + // From QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::optional context_id, absl::string_view payload) override; + // From QuicSpdyStream::Http3DatagramRegistrationVisitor. + void OnContextReceived(QuicStreamId stream_id, + absl::optional context_id, + DatagramFormatType format_type, + absl::string_view format_additional_data) override; + void OnContextClosed(QuicStreamId stream_id, + absl::optional context_id, + ContextCloseCode close_code, + absl::string_view close_details) override; + private: - absl::optional flow_id_; - QuicStreamId stream_id_; + QuicSpdyStream* stream_; + absl::optional context_id_; QuicSocketAddress target_server_address_; - QuicUdpSocketFd fd_; // Owned. + QuicUdpSocketFd fd_; // Owned. MasqueServerSession* masque_session_; // Unowned. + bool context_received_ = false; + bool context_registered_ = false; }; + // From QuicSpdySession. + bool OnSettingsFrame(const SettingsFrame& frame) override; + HttpDatagramSupport LocalHttpDatagramSupport() override { + return HttpDatagramSupport::kDraft00And04; + } + MasqueServerBackend* masque_server_backend_; // Unowned. Visitor* owner_; // Unowned. QuicEpollServer* epoll_server_; // Unowned. diff --git a/gquiche/quic/masque/masque_utils.h b/gquiche/quic/masque/masque_utils.h index 02ca2bc2..18de7f66 100644 --- a/gquiche/quic/masque/masque_utils.h +++ b/gquiche/quic/masque/masque_utils.h @@ -18,7 +18,10 @@ QUIC_NO_EXPORT ParsedQuicVersionVector MasqueSupportedVersions(); QUIC_NO_EXPORT QuicConfig MasqueEncapsulatedConfig(); // Maximum packet size for encapsulated connections. -enum : QuicByteCount { kMasqueMaxEncapsulatedPacketSize = 1300 }; +enum : QuicByteCount { + kMasqueMaxEncapsulatedPacketSize = 1300, + kMasqueMaxOuterPacketSize = 1350, +}; // Mode that MASQUE is operating in. enum class MasqueMode : uint8_t { diff --git a/gquiche/quic/platform/api/quic_bug_tracker.h b/gquiche/quic/platform/api/quic_bug_tracker.h index 89e65909..49bc91bc 100644 --- a/gquiche/quic/platform/api/quic_bug_tracker.h +++ b/gquiche/quic/platform/api/quic_bug_tracker.h @@ -5,11 +5,11 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_BUG_TRACKER_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_BUG_TRACKER_H_ -#include "platform/quic_platform_impl/quic_bug_tracker_impl.h" +#include "gquiche/common/platform/api/quiche_bug_tracker.h" -#define QUIC_BUG(x) QUICHE_BUG_IMPL(x) -#define QUIC_BUG_IF(x,y) QUICHE_BUG_IF_IMPL(x,y) -#define QUIC_PEER_BUG(x) QUICHE_PEER_BUG_IMPL(x) -#define QUIC_PEER_BUG_IF(x,y) QUICHE_PEER_BUG_IF_IMPL(x,y) +#define QUIC_BUG QUICHE_BUG +#define QUIC_BUG_IF QUICHE_BUG_IF +#define QUIC_PEER_BUG QUICHE_PEER_BUG +#define QUIC_PEER_BUG_IF QUICHE_PEER_BUG_IF #endif // QUICHE_QUIC_PLATFORM_API_QUIC_BUG_TRACKER_H_ diff --git a/gquiche/quic/platform/api/quic_containers.h b/gquiche/quic/platform/api/quic_containers.h index d4363cd5..335a8e66 100644 --- a/gquiche/quic/platform/api/quic_containers.h +++ b/gquiche/quic/platform/api/quic_containers.h @@ -5,60 +5,17 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_CONTAINERS_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_CONTAINERS_H_ -#include "platform/quic_platform_impl/quic_containers_impl.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_containers_impl.h" namespace quic { -// The default hasher used by hash tables. -template -using QuicDefaultHasher = QuicDefaultHasherImpl; - -// A general-purpose unordered map. -template > -using QuicUnorderedMap = QuicUnorderedMapImpl; - -// A general-purpose unordered map that does not gurantee pointer stability. -template > -using QuicHashMap = QuicHashMapImpl; - -// A general-purpose unordered set. -template > -using QuicUnorderedSet = QuicUnorderedSetImpl; - -// A general-purpose unordered set that does not gurantee pointer stability. -template > -using QuicHashSet = QuicHashSetImpl; - -// A map which offers insertion-ordered iteration. -template > -using QuicLinkedHashMap = QuicLinkedHashMapImpl; - -// Used for maps that are typically small, then it is faster than (for example) -// hash_map which is optimized for large data sets. QuicSmallMap upgrades itself -// automatically to a QuicSmallMapImpl-specified map when it runs out of space. -// -// DOES NOT GUARANTEE POINTER OR ITERATOR STABILITY! -template -using QuicSmallMap = QuicSmallMapImpl; - -// Represents a simple queue which may be backed by a list or -// a flat circular buffer. -// -// DOES NOT GUARANTEE POINTER OR ITERATOR STABILITY! -template -using QuicQueue = QuicQueueImpl; - -// A vector optimized for small sizes. Provides the same APIs as a std::vector. -template > -using QuicInlinedVector = QuicInlinedVectorImpl; - -// An ordered set of values. +// An ordered container optimized for small sets. +// An implementation with O(n) mutations might be chosen +// in case it has better memory usage and/or faster access. // // DOES NOT GUARANTEE POINTER OR ITERATOR STABILITY! -template , - typename Rep = std::vector> -using QuicOrderedSet = QuicOrderedSetImpl; +template > +using QuicSmallOrderedSet = ::quiche::QuicheSmallOrderedSetImpl; } // namespace quic diff --git a/gquiche/quic/platform/api/quic_epoll_test_tools.h b/gquiche/quic/platform/api/quic_epoll_test_tools.h index 5a7a9bb1..b968e71e 100644 --- a/gquiche/quic/platform/api/quic_epoll_test_tools.h +++ b/gquiche/quic/platform/api/quic_epoll_test_tools.h @@ -5,7 +5,7 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_EPOLL_TEST_TOOLS_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_EPOLL_TEST_TOOLS_H_ -#include "platform/quic_epoll_test_tools_impl.h" +#include "platform/quic_platform_impl/quic_epoll_test_tools_impl.h" using QuicFakeEpollServer = QuicFakeEpollServerImpl; diff --git a/gquiche/quic/platform/api/quic_export.h b/gquiche/quic/platform/api/quic_export.h index 0317cf0b..5dc99e46 100644 --- a/gquiche/quic/platform/api/quic_export.h +++ b/gquiche/quic/platform/api/quic_export.h @@ -5,7 +5,7 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_EXPORT_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_EXPORT_H_ -#include "platform/quiche_platform_impl/quiche_export_impl.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quiche_export_impl.h" // QUIC_EXPORT is not meant to be used. #define QUIC_EXPORT QUICHE_EXPORT_IMPL diff --git a/gquiche/quic/platform/api/quic_hostname_utils_test.cc b/gquiche/quic/platform/api/quic_hostname_utils_test.cc index 88d839c5..866b41d3 100644 --- a/gquiche/quic/platform/api/quic_hostname_utils_test.cc +++ b/gquiche/quic/platform/api/quic_hostname_utils_test.cc @@ -19,10 +19,7 @@ TEST_F(QuicHostnameUtilsTest, IsValidSNI) { // IP as SNI. EXPECT_FALSE(QuicHostnameUtils::IsValidSNI("192.168.0.1")); // SNI without any dot. - SetQuicReloadableFlag(quic_and_tls_allow_sni_without_dots, true); EXPECT_TRUE(QuicHostnameUtils::IsValidSNI("somedomain")); - SetQuicReloadableFlag(quic_and_tls_allow_sni_without_dots, false); - EXPECT_FALSE(QuicHostnameUtils::IsValidSNI("somedomain")); // Invalid by RFC2396 but unfortunately domains of this form exist. EXPECT_TRUE(QuicHostnameUtils::IsValidSNI("some_domain.com")); // An empty string must be invalid otherwise the QUIC client will try sending diff --git a/gquiche/quic/platform/api/quic_mem_slice.h b/gquiche/quic/platform/api/quic_mem_slice.h index 312a6f42..3cd3ccb5 100644 --- a/gquiche/quic/platform/api/quic_mem_slice.h +++ b/gquiche/quic/platform/api/quic_mem_slice.h @@ -6,6 +6,8 @@ #define QUICHE_QUIC_PLATFORM_API_QUIC_MEM_SLICE_H_ #include + +#include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_export.h" #include "platform/quic_platform_impl/quic_mem_slice_impl.h" @@ -31,9 +33,20 @@ class QUIC_EXPORT_PRIVATE QuicMemSlice { // Constructs a QuicMemSlice that takes ownership of |buffer|. |length| must // not be zero. To construct an empty QuicMemSlice, use the zero-argument // constructor instead. + // TODO(vasilvv): switch all users to QuicBuffer version, and make this + // private. QuicMemSlice(QuicUniqueBufferPtr buffer, size_t length) : impl_(std::move(buffer), length) {} + // Constructs a QuicMemSlice that takes ownership of |buffer|. The length of + // the |buffer| must not be zero. To construct an empty QuicMemSlice, use the + // zero-argument constructor instead. + explicit QuicMemSlice(QuicBuffer buffer) : QuicMemSlice() { + // Store the size of the buffer *before* calling buffer.Release(). + const size_t size = buffer.size(); + *this = QuicMemSlice(buffer.Release(), size); + } + // Constructs a QuicMemSlice that takes ownership of |buffer| allocated on // heap. |length| must not be zero. QuicMemSlice(std::unique_ptr buffer, size_t length) @@ -61,6 +74,10 @@ class QUIC_EXPORT_PRIVATE QuicMemSlice { const char* data() const { return impl_.data(); } // Returns the length of underlying data buffer. size_t length() const { return impl_.length(); } + // Returns the representation of the underlying data as a string view. + absl::string_view AsStringView() const { + return absl::string_view(data(), length()); + } bool empty() const { return impl_.empty(); } diff --git a/gquiche/quic/platform/api/quic_mem_slice_storage.cc b/gquiche/quic/platform/api/quic_mem_slice_storage.cc new file mode 100644 index 00000000..d2e5204f --- /dev/null +++ b/gquiche/quic/platform/api/quic_mem_slice_storage.cc @@ -0,0 +1,35 @@ +// Copyright 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/platform/api/quic_mem_slice_storage.h" + +#include "gquiche/quic/core/quic_utils.h" + +namespace quic { + +QuicMemSliceStorage::QuicMemSliceStorage(const struct iovec* iov, int iov_count, + QuicBufferAllocator* allocator, + const QuicByteCount max_slice_len) { + if (iov == nullptr) { + return; + } + QuicByteCount write_len = 0; + for (int i = 0; i < iov_count; ++i) { + write_len += iov[i].iov_len; + } + QUICHE_DCHECK_LT(0u, write_len); + + size_t io_offset = 0; + while (write_len > 0) { + size_t slice_len = std::min(write_len, max_slice_len); + QuicBuffer buffer(allocator, slice_len); + QuicUtils::CopyToBuffer(iov, iov_count, io_offset, slice_len, + buffer.data()); + storage_.push_back(QuicMemSlice(std::move(buffer))); + write_len -= slice_len; + io_offset += slice_len; + } +} + +} // namespace quic diff --git a/gquiche/quic/platform/api/quic_mem_slice_storage.h b/gquiche/quic/platform/api/quic_mem_slice_storage.h index 24b95e75..94039d5e 100644 --- a/gquiche/quic/platform/api/quic_mem_slice_storage.h +++ b/gquiche/quic/platform/api/quic_mem_slice_storage.h @@ -5,8 +5,14 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_MEM_SLICE_STORAGE_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_MEM_SLICE_STORAGE_H_ +#include + +#include "absl/types/span.h" +#include "gquiche/quic/core/quic_buffer_allocator.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/platform/api/quic_export.h" -#include "platform/quic_platform_impl/quic_mem_slice_storage_impl.h" +#include "gquiche/quic/platform/api/quic_iovec.h" +#include "gquiche/quic/platform/api/quic_mem_slice.h" namespace quic { @@ -14,11 +20,9 @@ namespace quic { // use cases such as turning into QuicMemSliceSpan. class QUIC_EXPORT_PRIVATE QuicMemSliceStorage { public: - QuicMemSliceStorage(const struct iovec* iov, - int iov_count, + QuicMemSliceStorage(const struct iovec* iov, int iov_count, QuicBufferAllocator* allocator, - const QuicByteCount max_slice_len) - : impl_(iov, iov_count, allocator, max_slice_len) {} + const QuicByteCount max_slice_len); QuicMemSliceStorage(const QuicMemSliceStorage& other) = default; QuicMemSliceStorage& operator=(const QuicMemSliceStorage& other) = default; @@ -28,12 +32,10 @@ class QUIC_EXPORT_PRIVATE QuicMemSliceStorage { ~QuicMemSliceStorage() = default; // Return a QuicMemSliceSpan form of the storage. - QuicMemSliceSpan ToSpan() { return impl_.ToSpan(); } - - void Append(QuicMemSlice slice) { impl_.Append(std::move(*slice.impl())); } + absl::Span ToSpan() { return absl::MakeSpan(storage_); } private: - QuicMemSliceStorageImpl impl_; + std::vector storage_; }; } // namespace quic diff --git a/gquiche/quic/platform/api/quic_mem_slice_storage_test.cc b/gquiche/quic/platform/api/quic_mem_slice_storage_test.cc index f42f3faa..f6d087c5 100644 --- a/gquiche/quic/platform/api/quic_mem_slice_storage_test.cc +++ b/gquiche/quic/platform/api/quic_mem_slice_storage_test.cc @@ -6,7 +6,6 @@ #include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/quic/platform/api/quic_test_mem_slice_vector.h" namespace quic { namespace test { @@ -28,8 +27,8 @@ TEST_F(QuicMemSliceStorageImplTest, SingleIov) { struct iovec iov = {const_cast(body.data()), body.length()}; QuicMemSliceStorage storage(&iov, 1, &allocator, 1024); auto span = storage.ToSpan(); - EXPECT_EQ("ccc", span.GetData(0)); - EXPECT_NE(static_cast(span.GetData(0).data()), body.data()); + EXPECT_EQ("ccc", span[0].AsStringView()); + EXPECT_NE(static_cast(span[0].data()), body.data()); } TEST_F(QuicMemSliceStorageImplTest, MultipleIovInSingleSlice) { @@ -41,7 +40,7 @@ TEST_F(QuicMemSliceStorageImplTest, MultipleIovInSingleSlice) { QuicMemSliceStorage storage(iov, 2, &allocator, 1024); auto span = storage.ToSpan(); - EXPECT_EQ("aaabbbb", span.GetData(0)); + EXPECT_EQ("aaabbbb", span[0].AsStringView()); } TEST_F(QuicMemSliceStorageImplTest, MultipleIovInMultipleSlice) { @@ -53,26 +52,8 @@ TEST_F(QuicMemSliceStorageImplTest, MultipleIovInMultipleSlice) { QuicMemSliceStorage storage(iov, 2, &allocator, 4); auto span = storage.ToSpan(); - EXPECT_EQ("aaaa", span.GetData(0)); - EXPECT_EQ("bbbb", span.GetData(1)); -} - -TEST_F(QuicMemSliceStorageImplTest, AppendMemSlices) { - std::string body1(3, 'a'); - std::string body2(4, 'b'); - std::vector> buffers; - buffers.push_back( - std::make_pair(const_cast(body1.data()), body1.length())); - buffers.push_back( - std::make_pair(const_cast(body2.data()), body2.length())); - QuicTestMemSliceVector mem_slices(buffers); - - QuicMemSliceStorage storage(nullptr, 0, nullptr, 0); - mem_slices.span().ConsumeAll( - [&storage](QuicMemSlice slice) { storage.Append(std::move(slice)); }); - - EXPECT_EQ("aaa", storage.ToSpan().GetData(0)); - EXPECT_EQ("bbbb", storage.ToSpan().GetData(1)); + EXPECT_EQ("aaaa", span[0].AsStringView()); + EXPECT_EQ("bbbb", span[1].AsStringView()); } } // namespace diff --git a/gquiche/quic/platform/api/quic_mem_slice_test.cc b/gquiche/quic/platform/api/quic_mem_slice_test.cc index 82a0c2c1..07cc3c39 100644 --- a/gquiche/quic/platform/api/quic_mem_slice_test.cc +++ b/gquiche/quic/platform/api/quic_mem_slice_test.cc @@ -5,6 +5,9 @@ #include "gquiche/quic/platform/api/quic_mem_slice.h" #include + +#include "absl/strings/string_view.h" +#include "gquiche/quic/core/quic_buffer_allocator.h" #include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/platform/api/quic_test.h" @@ -56,6 +59,18 @@ TEST_F(QuicMemSliceTest, SliceAllocatedOnHeap) { EXPECT_EQ(moved.length(), used_length); } +TEST_F(QuicMemSliceTest, SliceFromBuffer) { + const absl::string_view kTestString = + "RFC 9000 Release Celebration Memorial Test String"; + auto buffer = QuicBuffer::Copy(&allocator_, kTestString); + QuicMemSlice slice(std::move(buffer)); + + EXPECT_EQ(buffer.data(), nullptr); // NOLINT(bugprone-use-after-move) + EXPECT_EQ(buffer.size(), 0u); + EXPECT_EQ(slice.AsStringView(), kTestString); + EXPECT_EQ(slice.length(), kTestString.length()); +} + } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/platform/api/quic_mock_log.h b/gquiche/quic/platform/api/quic_mock_log.h index fb2639d1..75d419b1 100644 --- a/gquiche/quic/platform/api/quic_mock_log.h +++ b/gquiche/quic/platform/api/quic_mock_log.h @@ -5,7 +5,7 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_MOCK_LOG_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_MOCK_LOG_H_ -#include "platform/quic_mock_log_impl.h" +#include "platform/quic_platform_impl/quic_mock_log_impl.h" using QuicMockLog = QuicMockLogImpl; #define CREATE_QUIC_MOCK_LOG(log) CREATE_QUIC_MOCK_LOG_IMPL(log) diff --git a/gquiche/quic/platform/api/quic_mutex.h b/gquiche/quic/platform/api/quic_mutex.h index e5422e08..4ab8c70c 100644 --- a/gquiche/quic/platform/api/quic_mutex.h +++ b/gquiche/quic/platform/api/quic_mutex.h @@ -6,7 +6,7 @@ #define QUICHE_QUIC_PLATFORM_API_QUIC_MUTEX_H_ // TODO(b/178613777): move into the common QUICHE platform. -#include "platform/quiche_platform_impl/quic_mutex_impl.h" +#include "gquiche/common/platform/default/quiche_platform_impl/quic_mutex_impl.h" #define QUIC_EXCLUSIVE_LOCKS_REQUIRED QUIC_EXCLUSIVE_LOCKS_REQUIRED_IMPL #define QUIC_GUARDED_BY QUIC_GUARDED_BY_IMPL diff --git a/gquiche/quic/platform/api/quic_sleep.h b/gquiche/quic/platform/api/quic_sleep.h index 92fc98bc..bb340b51 100644 --- a/gquiche/quic/platform/api/quic_sleep.h +++ b/gquiche/quic/platform/api/quic_sleep.h @@ -6,7 +6,8 @@ #define QUICHE_QUIC_PLATFORM_API_QUIC_SLEEP_H_ #include "gquiche/quic/core/quic_time.h" -#include "platform/quic_platform_impl/quic_sleep_impl.h" +// TODO(b/178613777): move into the common QUICHE platform. +#include "quiche_platform_impl/quiche_sleep_impl.h" namespace quic { diff --git a/gquiche/quic/platform/api/quic_socket_address.cc b/gquiche/quic/platform/api/quic_socket_address.cc index 3de9e186..3852141a 100644 --- a/gquiche/quic/platform/api/quic_socket_address.cc +++ b/gquiche/quic/platform/api/quic_socket_address.cc @@ -15,6 +15,23 @@ namespace quic { +namespace { + +uint32_t HashIP(const QuicIpAddress& ip) { + if (ip.IsIPv4()) { + return ip.GetIPv4().s_addr; + } + if (ip.IsIPv6()) { + auto v6addr = ip.GetIPv6(); + const uint32_t* v6_as_ints = + reinterpret_cast(&v6addr.s6_addr); + return v6_as_ints[0] ^ v6_as_ints[1] ^ v6_as_ints[2] ^ v6_as_ints[3]; + } + return 0; +} + +} // namespace + QuicSocketAddress::QuicSocketAddress(QuicIpAddress address, uint16_t port) : host_(address), port_(port) {} @@ -131,4 +148,11 @@ sockaddr_storage QuicSocketAddress::generic_address() const { return result.storage; } +uint32_t QuicSocketAddress::Hash() const { + uint32_t value = 0; + value ^= HashIP(host_); + value ^= port_ | (port_ << 16); + return value; +} + } // namespace quic diff --git a/gquiche/quic/platform/api/quic_socket_address.h b/gquiche/quic/platform/api/quic_socket_address.h index 849b6821..18152745 100644 --- a/gquiche/quic/platform/api/quic_socket_address.h +++ b/gquiche/quic/platform/api/quic_socket_address.h @@ -37,6 +37,9 @@ class QUIC_EXPORT_PRIVATE QuicSocketAddress { uint16_t port() const; sockaddr_storage generic_address() const; + // Hashes this address to an uint32_t. + uint32_t Hash() const; + private: QuicIpAddress host_; uint16_t port_ = 0; @@ -48,6 +51,13 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +class QUIC_EXPORT_PRIVATE QuicSocketAddressHash { + public: + size_t operator()(QuicSocketAddress const& address) const noexcept { + return address.Hash(); + } +}; + } // namespace quic #endif // QUICHE_QUIC_PLATFORM_API_QUIC_SOCKET_ADDRESS_H_ diff --git a/gquiche/quic/platform/api/quic_test.h b/gquiche/quic/platform/api/quic_test.h index 57b493eb..d432ae84 100644 --- a/gquiche/quic/platform/api/quic_test.h +++ b/gquiche/quic/platform/api/quic_test.h @@ -7,6 +7,7 @@ #include "gquiche/quic/platform/api/quic_logging.h" #include "platform/quic_platform_impl/quic_test_impl.h" +#include "gquiche/common/platform/api/quiche_test.h" using QuicFlagSaver = QuicFlagSaverImpl; @@ -25,9 +26,6 @@ inline std::string QuicGetTestMemoryCachePath() { return QuicGetTestMemoryCachePathImpl(); } -#define EXPECT_QUIC_DEBUG_DEATH(condition, message) \ - EXPECT_QUIC_DEBUG_DEATH_IMPL(condition, message) - #define QUIC_SLOW_TEST(test) QUIC_SLOW_TEST_IMPL(test) #endif // QUICHE_QUIC_PLATFORM_API_QUIC_TEST_H_ diff --git a/gquiche/quic/platform/api/quic_test_loopback.h b/gquiche/quic/platform/api/quic_test_loopback.h index 9277285a..b8f0c428 100644 --- a/gquiche/quic/platform/api/quic_test_loopback.h +++ b/gquiche/quic/platform/api/quic_test_loopback.h @@ -5,7 +5,7 @@ #ifndef QUICHE_QUIC_PLATFORM_API_QUIC_TEST_LOOPBACK_H_ #define QUICHE_QUIC_PLATFORM_API_QUIC_TEST_LOOPBACK_H_ -#include "platform/quic_test_loopback_impl.h" +#include "platform/quic_platform_impl/quic_test_loopback_impl.h" namespace quic { diff --git a/gquiche/quic/platform/api/quic_test_output.h b/gquiche/quic/platform/api/quic_test_output.h index c360f277..fbc410eb 100644 --- a/gquiche/quic/platform/api/quic_test_output.h +++ b/gquiche/quic/platform/api/quic_test_output.h @@ -6,7 +6,7 @@ #define QUICHE_QUIC_PLATFORM_API_QUIC_TEST_OUTPUT_H_ #include "absl/strings/string_view.h" -#include "platform/quic_test_output_impl.h" +#include "platform/quic_platform_impl/quic_test_output_impl.h" namespace quic { diff --git a/gquiche/quic/qbone/bonnet/icmp_reachable.cc b/gquiche/quic/qbone/bonnet/icmp_reachable.cc index 014359d5..9d7faabd 100644 --- a/gquiche/quic/qbone/bonnet/icmp_reachable.cc +++ b/gquiche/quic/qbone/bonnet/icmp_reachable.cc @@ -11,8 +11,8 @@ #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/platform/api/quic_mutex.h" #include "gquiche/quic/qbone/platform/icmp_packet.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/quiche_endian.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace { diff --git a/gquiche/quic/qbone/bonnet/mock_tun_device.h b/gquiche/quic/qbone/bonnet/mock_tun_device.h index de62d421..a5aebc7f 100644 --- a/gquiche/quic/qbone/bonnet/mock_tun_device.h +++ b/gquiche/quic/qbone/bonnet/mock_tun_device.h @@ -18,6 +18,8 @@ class MockTunDevice : public TunDeviceInterface { MOCK_METHOD(bool, Down, (), (override)); + MOCK_METHOD(void, CloseDevice, (), (override)); + MOCK_METHOD(int, GetFileDescriptor, (), (const, override)); }; diff --git a/gquiche/quic/qbone/bonnet/tun_device.cc b/gquiche/quic/qbone/bonnet/tun_device.cc index 6c04146d..61325865 100644 --- a/gquiche/quic/qbone/bonnet/tun_device.cc +++ b/gquiche/quic/qbone/bonnet/tun_device.cc @@ -10,6 +10,7 @@ #include #include +#include "absl/cleanup/cleanup.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/qbone/platform/kernel_interface.h" @@ -23,26 +24,25 @@ namespace quic { const int kInvalidFd = -1; -TunDevice::TunDevice(const std::string& interface_name, - int mtu, - bool persist, - bool setup_tun, - KernelInterface* kernel) +TunTapDevice::TunTapDevice(const std::string& interface_name, int mtu, + bool persist, bool setup_tun, bool is_tap, + KernelInterface* kernel) : interface_name_(interface_name), mtu_(mtu), persist_(persist), setup_tun_(setup_tun), + is_tap_(is_tap), file_descriptor_(kInvalidFd), kernel_(*kernel) {} -TunDevice::~TunDevice() { +TunTapDevice::~TunTapDevice() { if (!persist_) { Down(); } - CleanUpFileDescriptor(); + CloseDevice(); } -bool TunDevice::Init() { +bool TunTapDevice::Init() { if (interface_name_.empty() || interface_name_.size() >= IFNAMSIZ) { QUIC_BUG(quic_bug_10995_1) << "interface_name must be nonempty and shorter than " << IFNAMSIZ; @@ -62,47 +62,43 @@ bool TunDevice::Init() { // TODO(pengg): might be better to use netlink socket, once we have a library to // use -bool TunDevice::Up() { - if (setup_tun_ && !is_interface_up_) { - struct ifreq if_request; - memset(&if_request, 0, sizeof(if_request)); - // copy does not zero-terminate the result string, but we've memset the - // entire struct. - interface_name_.copy(if_request.ifr_name, IFNAMSIZ); - if_request.ifr_flags = IFF_UP; - - is_interface_up_ = - NetdeviceIoctl(SIOCSIFFLAGS, reinterpret_cast(&if_request)); - return is_interface_up_; - } else { +bool TunTapDevice::Up() { + if (!setup_tun_) { return true; } + struct ifreq if_request; + memset(&if_request, 0, sizeof(if_request)); + // copy does not zero-terminate the result string, but we've memset the + // entire struct. + interface_name_.copy(if_request.ifr_name, IFNAMSIZ); + if_request.ifr_flags = IFF_UP; + + return NetdeviceIoctl(SIOCSIFFLAGS, reinterpret_cast(&if_request)); } // TODO(pengg): might be better to use netlink socket, once we have a library to // use -bool TunDevice::Down() { - if (setup_tun_ && is_interface_up_) { - struct ifreq if_request; - memset(&if_request, 0, sizeof(if_request)); - // copy does not zero-terminate the result string, but we've memset the - // entire struct. - interface_name_.copy(if_request.ifr_name, IFNAMSIZ); - if_request.ifr_flags = 0; - - is_interface_up_ = - !NetdeviceIoctl(SIOCSIFFLAGS, reinterpret_cast(&if_request)); - return !is_interface_up_; - } else { +bool TunTapDevice::Down() { + if (!setup_tun_) { return true; } -} + struct ifreq if_request; + memset(&if_request, 0, sizeof(if_request)); + // copy does not zero-terminate the result string, but we've memset the + // entire struct. + interface_name_.copy(if_request.ifr_name, IFNAMSIZ); + if_request.ifr_flags = 0; -int TunDevice::GetFileDescriptor() const { - return file_descriptor_; + return NetdeviceIoctl(SIOCSIFFLAGS, reinterpret_cast(&if_request)); } -bool TunDevice::OpenDevice() { +int TunTapDevice::GetFileDescriptor() const { return file_descriptor_; } + +bool TunTapDevice::OpenDevice() { + if (file_descriptor_ != kInvalidFd) { + CloseDevice(); + } + struct ifreq if_request; memset(&if_request, 0, sizeof(if_request)); // copy does not zero-terminate the result string, but we've memset the entire @@ -114,46 +110,53 @@ bool TunDevice::OpenDevice() { // destroy the device and create a new one, but that deletes any existing // routing associated with the interface, which makes the meaning of the // 'persist' bit ambiguous. - if_request.ifr_flags = IFF_TUN | IFF_MULTI_QUEUE | IFF_NO_PI; + if_request.ifr_flags = IFF_MULTI_QUEUE | IFF_NO_PI; + if (is_tap_) { + if_request.ifr_flags |= IFF_TAP; + } else { + if_request.ifr_flags |= IFF_TUN; + } - // TODO(pengg): port MakeCleanup to quic/platform? This makes the call to - // CleanUpFileDescriptor nicer and less error-prone. // When the device is running with IFF_MULTI_QUEUE set, each call to open will // create a queue which can be used to read/write packets from/to the device. + bool successfully_opened = false; + auto cleanup = absl::MakeCleanup([this, &successfully_opened]() { + if (!successfully_opened) { + CloseDevice(); + } + }); + const std::string tun_device_path = absl::GetFlag(FLAGS_qbone_client_tun_device_path); int fd = kernel_.open(tun_device_path.c_str(), O_RDWR); if (fd < 0) { QUIC_PLOG(WARNING) << "Failed to open " << tun_device_path; - CleanUpFileDescriptor(); - return false; + return successfully_opened; } file_descriptor_ = fd; if (!CheckFeatures(fd)) { - CleanUpFileDescriptor(); - return false; + return successfully_opened; } if (kernel_.ioctl(fd, TUNSETIFF, reinterpret_cast(&if_request)) != 0) { QUIC_PLOG(WARNING) << "Failed to TUNSETIFF on fd(" << fd << ")"; - CleanUpFileDescriptor(); - return false; + return successfully_opened; } if (kernel_.ioctl( fd, TUNSETPERSIST, persist_ ? reinterpret_cast(&if_request) : nullptr) != 0) { QUIC_PLOG(WARNING) << "Failed to TUNSETPERSIST on fd(" << fd << ")"; - CleanUpFileDescriptor(); - return false; + return successfully_opened; } - return true; + successfully_opened = true; + return successfully_opened; } // TODO(pengg): might be better to use netlink socket, once we have a library to // use -bool TunDevice::ConfigureInterface() { +bool TunTapDevice::ConfigureInterface() { if (!setup_tun_) { return true; } @@ -166,14 +169,14 @@ bool TunDevice::ConfigureInterface() { if_request.ifr_mtu = mtu_; if (!NetdeviceIoctl(SIOCSIFMTU, reinterpret_cast(&if_request))) { - CleanUpFileDescriptor(); + CloseDevice(); return false; } return true; } -bool TunDevice::CheckFeatures(int tun_device_fd) { +bool TunTapDevice::CheckFeatures(int tun_device_fd) { unsigned int actual_features; if (kernel_.ioctl(tun_device_fd, TUNGETFEATURES, &actual_features) != 0) { QUIC_PLOG(WARNING) << "Failed to TUNGETFEATURES"; @@ -190,7 +193,7 @@ bool TunDevice::CheckFeatures(int tun_device_fd) { return true; } -bool TunDevice::NetdeviceIoctl(int request, void* argp) { +bool TunTapDevice::NetdeviceIoctl(int request, void* argp) { int fd = kernel_.socket(AF_INET6, SOCK_DGRAM, 0); if (fd < 0) { QUIC_PLOG(WARNING) << "Failed to create AF_INET6 socket."; @@ -206,7 +209,7 @@ bool TunDevice::NetdeviceIoctl(int request, void* argp) { return true; } -void TunDevice::CleanUpFileDescriptor() { +void TunTapDevice::CloseDevice() { if (file_descriptor_ != kInvalidFd) { kernel_.close(file_descriptor_); file_descriptor_ = kInvalidFd; diff --git a/gquiche/quic/qbone/bonnet/tun_device.h b/gquiche/quic/qbone/bonnet/tun_device.h index 0430de20..9f253742 100644 --- a/gquiche/quic/qbone/bonnet/tun_device.h +++ b/gquiche/quic/qbone/bonnet/tun_device.h @@ -13,7 +13,7 @@ namespace quic { -class TunDevice : public TunDeviceInterface { +class TunTapDevice : public TunDeviceInterface { public: // This represents a tun device created in the OS kernel, which is a virtual // network interface that any packets sent to it can be read by a user space @@ -32,13 +32,10 @@ class TunDevice : public TunDeviceInterface { // routing rules go away. // // The caller should own kernel and make sure it outlives this. - TunDevice(const std::string& interface_name, - int mtu, - bool persist, - bool setup_tun, - KernelInterface* kernel); + TunTapDevice(const std::string& interface_name, int mtu, bool persist, + bool setup_tun, bool is_tap, KernelInterface* kernel); - ~TunDevice() override; + ~TunTapDevice() override; // Actually creates/reopens and configures the device. bool Init() override; @@ -49,6 +46,11 @@ class TunDevice : public TunDeviceInterface { // Marks the interface down to stop receiving packets. bool Down() override; + // Closes the open file descriptor for the TUN device (if one exists). + // It is safe to reinitialize and reuse this TunTapDevice after calling + // CloseDevice. + void CloseDevice() override; + // Gets the file descriptor that can be used to send/receive packets. // This returns -1 when the TUN device is in an invalid state. int GetFileDescriptor() const override; @@ -63,10 +65,6 @@ class TunDevice : public TunDeviceInterface { // Checks if the required kernel features exists. bool CheckFeatures(int tun_device_fd); - // Closes the opened file descriptor and makes sure the file descriptor - // is no longer available from GetFileDescriptor; - void CleanUpFileDescriptor(); - // Opens a socket and makes netdevice ioctl call bool NetdeviceIoctl(int request, void* argp); @@ -74,9 +72,9 @@ class TunDevice : public TunDeviceInterface { const int mtu_; const bool persist_; const bool setup_tun_; + const bool is_tap_; int file_descriptor_; KernelInterface& kernel_; - bool is_interface_up_ = false; }; } // namespace quic diff --git a/gquiche/quic/qbone/bonnet/tun_device_controller.cc b/gquiche/quic/qbone/bonnet/tun_device_controller.cc index bb6a701f..51e0c4ad 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_controller.cc +++ b/gquiche/quic/qbone/bonnet/tun_device_controller.cc @@ -49,6 +49,10 @@ bool TunDeviceController::UpdateAddress(const IpRange& desired_range) { if (address_updated) { current_address_ = desired_address; + + for (const auto& cb : address_update_cbs_) { + cb(current_address_); + } } return address_updated; @@ -110,6 +114,19 @@ bool TunDeviceController::UpdateRoutes( return true; } +bool TunDeviceController::UpdateRoutesWithRetries( + const IpRange& desired_range, + const std::vector& desired_routes, + int retries) { + while (retries-- > 0) { + if (UpdateRoutes(desired_range, desired_routes)) { + return true; + } + absl::SleepFor(absl::Milliseconds(100)); + } + return false; +} + bool TunDeviceController::UpdateRules(IpRange desired_range) { if (!absl::GetFlag(FLAGS_qbone_tun_device_replace_default_routing_rules)) { return true; @@ -148,4 +165,9 @@ QuicIpAddress TunDeviceController::current_address() { return current_address_; } +void TunDeviceController::RegisterAddressUpdateCallback( + const std::function& cb) { + address_update_cbs_.push_back(cb); +} + } // namespace quic diff --git a/gquiche/quic/qbone/bonnet/tun_device_controller.h b/gquiche/quic/qbone/bonnet/tun_device_controller.h index 3c251ed4..14ac611c 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_controller.h +++ b/gquiche/quic/qbone/bonnet/tun_device_controller.h @@ -39,6 +39,20 @@ class TunDeviceController { virtual bool UpdateRoutes(const IpRange& desired_range, const std::vector& desired_routes); + // Same as UpdateRoutes, but will wait and retry up to the number of times + // given by |retries| before giving up. This is an unpleasant workaround to + // deal with older kernels that aren't always able to set a route with a + // source address immediately after adding the address to the interface. + // + // TODO(b/179430548): Remove this once we've root-caused the underlying issue. + virtual bool UpdateRoutesWithRetries( + const IpRange& desired_range, + const std::vector& desired_routes, + int retries); + + virtual void RegisterAddressUpdateCallback( + const std::function& cb); + virtual QuicIpAddress current_address(); private: @@ -51,6 +65,8 @@ class TunDeviceController { NetlinkInterface* netlink_; QuicIpAddress current_address_; + + std::vector> address_update_cbs_; }; } // namespace quic diff --git a/gquiche/quic/qbone/bonnet/tun_device_controller_test.cc b/gquiche/quic/qbone/bonnet/tun_device_controller_test.cc index a0090038..787ed1b7 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_controller_test.cc +++ b/gquiche/quic/qbone/bonnet/tun_device_controller_test.cc @@ -44,8 +44,10 @@ class TunDeviceControllerTest : public QuicTest { public: TunDeviceControllerTest() : controller_(kIfname, true, &netlink_), - link_local_range_( - *QboneConstants::TerminatorLocalAddressRange()) {} + link_local_range_(*QboneConstants::TerminatorLocalAddressRange()) { + controller_.RegisterAddressUpdateCallback( + [this](QuicIpAddress address) { notified_address_ = address; }); + } protected: void ExpectLinkInfo(const std::string& interface_name, int ifindex) { @@ -60,6 +62,7 @@ class TunDeviceControllerTest : public QuicTest { MockNetlink netlink_; TunDeviceController controller_; + QuicIpAddress notified_address_; IpRange link_local_range_; }; @@ -77,6 +80,7 @@ TEST_F(TunDeviceControllerTest, AddressAppliedWhenNoneExisted) { .WillOnce(Return(true)); EXPECT_TRUE(controller_.UpdateAddress(kIpRange)); + EXPECT_THAT(notified_address_, Eq(kIpRange.FirstAddressInRange())); } TEST_F(TunDeviceControllerTest, OldAddressesAreRemoved) { @@ -110,6 +114,7 @@ TEST_F(TunDeviceControllerTest, OldAddressesAreRemoved) { .WillOnce(Return(true)); EXPECT_TRUE(controller_.UpdateAddress(kIpRange)); + EXPECT_THAT(notified_address_, Eq(kIpRange.FirstAddressInRange())); } TEST_F(TunDeviceControllerTest, UpdateRoutesRemovedOldRoutes) { diff --git a/gquiche/quic/qbone/bonnet/tun_device_interface.h b/gquiche/quic/qbone/bonnet/tun_device_interface.h index e99c547e..e88efa97 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_interface.h +++ b/gquiche/quic/qbone/bonnet/tun_device_interface.h @@ -12,7 +12,7 @@ namespace quic { // An interface with methods for interacting with a TUN device. class TunDeviceInterface { public: - virtual ~TunDeviceInterface() {} + virtual ~TunDeviceInterface() = default; // Actually creates/reopens and configures the device. virtual bool Init() = 0; @@ -23,6 +23,11 @@ class TunDeviceInterface { // Marks the interface down to stop receiving packets. virtual bool Down() = 0; + // Closes the open file descriptor for the TUN device (if one exists). + // It is safe to reinitialize and reuse this TunTapDevice after calling + // CloseDevice. + virtual void CloseDevice() = 0; + // Gets the file descriptor that can be used to send/receive packets. // This returns -1 when the TUN device is in an invalid state. virtual int GetFileDescriptor() const = 0; diff --git a/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc b/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc index ca692baf..6730d479 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc +++ b/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.cc @@ -4,29 +4,36 @@ #include "gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.h" +#include +#include + #include #include "absl/strings/str_cat.h" +#include "gquiche/quic/qbone/platform/icmp_packet.h" +#include "gquiche/quic/qbone/platform/netlink_interface.h" +#include "gquiche/quic/qbone/qbone_constants.h" namespace quic { TunDevicePacketExchanger::TunDevicePacketExchanger( - int fd, - size_t mtu, - KernelInterface* kernel, - QbonePacketExchanger::Visitor* visitor, - size_t max_pending_packets, - StatsInterface* stats) + size_t mtu, KernelInterface* kernel, NetlinkInterface* netlink, + QbonePacketExchanger::Visitor* visitor, size_t max_pending_packets, + bool is_tap, StatsInterface* stats, absl::string_view ifname) : QbonePacketExchanger(visitor, max_pending_packets), - fd_(fd), mtu_(mtu), kernel_(kernel), - stats_(stats) {} + netlink_(netlink), + ifname_(ifname), + is_tap_(is_tap), + stats_(stats) { + if (is_tap_) { + mtu_ += ETH_HLEN; + } +} -bool TunDevicePacketExchanger::WritePacket(const char* packet, - size_t size, - bool* blocked, - std::string* error) { +bool TunDevicePacketExchanger::WritePacket(const char* packet, size_t size, + bool* blocked, std::string* error) { *blocked = false; if (fd_ < 0) { *error = absl::StrCat("Invalid file descriptor of the TUN device: ", fd_); @@ -34,7 +41,11 @@ bool TunDevicePacketExchanger::WritePacket(const char* packet, return false; } - int result = kernel_->write(fd_, packet, size); + auto buffer = std::make_unique(packet, size); + if (is_tap_) { + buffer = ApplyL2Headers(*buffer); + } + int result = kernel_->write(fd_, buffer->data(), buffer->length()); if (result == -1) { if (errno == EWOULDBLOCK || errno == EAGAIN) { // The tunnel is blocked. Note that this does not mean the receive buffer @@ -52,8 +63,7 @@ bool TunDevicePacketExchanger::WritePacket(const char* packet, } std::unique_ptr TunDevicePacketExchanger::ReadPacket( - bool* blocked, - std::string* error) { + bool* blocked, std::string* error) { *blocked = false; if (fd_ < 0) { *error = absl::StrCat("Invalid file descriptor of the TUN device: ", fd_); @@ -74,17 +84,147 @@ std::unique_ptr TunDevicePacketExchanger::ReadPacket( } return nullptr; } - stats_->OnPacketRead(result); - return std::make_unique(read_buffer.release(), result, true); -} -int TunDevicePacketExchanger::file_descriptor() const { - return fd_; + auto buffer = std::make_unique(read_buffer.release(), result, true); + if (is_tap_) { + buffer = ConsumeL2Headers(*buffer); + } + if (buffer) { + stats_->OnPacketRead(buffer->length()); + } + return buffer; } +void TunDevicePacketExchanger::set_file_descriptor(int fd) { fd_ = fd; } + const TunDevicePacketExchanger::StatsInterface* TunDevicePacketExchanger::stats_interface() const { return stats_; } +std::unique_ptr TunDevicePacketExchanger::ApplyL2Headers( + const QuicData& l3_packet) { + if (is_tap_ && !mac_initialized_) { + NetlinkInterface::LinkInfo link_info{}; + if (netlink_->GetLinkInfo(ifname_, &link_info)) { + memcpy(tap_mac_, link_info.hardware_address, ETH_ALEN); + mac_initialized_ = true; + } else { + QUIC_LOG_EVERY_N_SEC(ERROR, 30) + << "Unable to get link info for: " << ifname_; + } + } + + const auto l2_packet_size = l3_packet.length() + ETH_HLEN; + auto l2_buffer = std::make_unique(l2_packet_size); + + // Populate the Ethernet header + auto* hdr = reinterpret_cast(l2_buffer.get()); + // Set src & dst to my own address + memcpy(hdr->h_dest, tap_mac_, ETH_ALEN); + memcpy(hdr->h_source, tap_mac_, ETH_ALEN); + // Assume ipv6 for now + // TODO(b/195113643): Support additional protocols. + hdr->h_proto = absl::ghtons(ETH_P_IPV6); + + // Copy the l3 packet into buffer, just after the ethernet header. + memcpy(l2_buffer.get() + ETH_HLEN, l3_packet.data(), l3_packet.length()); + + return std::make_unique(l2_buffer.release(), l2_packet_size, true); +} + +std::unique_ptr TunDevicePacketExchanger::ConsumeL2Headers( + const QuicData& l2_packet) { + if (l2_packet.length() < ETH_HLEN) { + // Packet is too short for ethernet headers. Drop it. + return nullptr; + } + auto* hdr = reinterpret_cast(l2_packet.data()); + if (hdr->h_proto != absl::ghtons(ETH_P_IPV6)) { + return nullptr; + } + constexpr auto kIp6PrefixLen = ETH_HLEN + sizeof(ip6_hdr); + constexpr auto kIcmp6PrefixLen = kIp6PrefixLen + sizeof(icmp6_hdr); + if (l2_packet.length() < kIp6PrefixLen) { + // Packet is too short to be ipv6. Drop it. + return nullptr; + } + auto* ip_hdr = reinterpret_cast(l2_packet.data() + ETH_HLEN); + const bool is_icmp = ip_hdr->ip6_ctlun.ip6_un1.ip6_un1_nxt == IPPROTO_ICMPV6; + + bool is_neighbor_solicit = false; + if (is_icmp) { + if (l2_packet.length() < kIcmp6PrefixLen) { + // Packet is too short to be icmp6. Drop it. + return nullptr; + } + is_neighbor_solicit = + reinterpret_cast(l2_packet.data() + kIp6PrefixLen) + ->icmp6_type == ND_NEIGHBOR_SOLICIT; + } + + if (is_neighbor_solicit) { + // If we've received a neighbor solicitation, craft an advertisement to + // respond with and write it back to the local interface. + auto* icmp6_payload = l2_packet.data() + kIcmp6PrefixLen; + + QuicIpAddress target_address( + *reinterpret_cast(icmp6_payload)); + if (target_address != *QboneConstants::GatewayAddress()) { + // Only respond to solicitations for our gateway address + return nullptr; + } + + // Neighbor Advertisement crafted per: + // https://datatracker.ietf.org/doc/html/rfc4861#section-4.4 + // + // Using the Target link-layer address option defined at: + // https://datatracker.ietf.org/doc/html/rfc4861#section-4.6.1 + constexpr size_t kIcmpv6OptionSize = 8; + const int payload_size = sizeof(in6_addr) + kIcmpv6OptionSize; + auto payload = std::make_unique(payload_size); + // Place the solicited IPv6 address at the beginning of the response payload + memcpy(payload.get(), icmp6_payload, sizeof(in6_addr)); + // Setup the Target link-layer address option: + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + // | Type | Length | Link-Layer Address ... + // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + int pos = sizeof(in6_addr); + payload[pos++] = ND_OPT_TARGET_LINKADDR; // Type + payload[pos++] = 1; // Length in units of 8 octets + memcpy(&payload[pos], tap_mac_, ETH_ALEN); // This interfaces' MAC address + + // Populate the ICMPv6 header + icmp6_hdr response_hdr{}; + response_hdr.icmp6_type = ND_NEIGHBOR_ADVERT; + // Set the solicited bit to true + response_hdr.icmp6_dataun.icmp6_un_data8[0] = 64; + // Craft the full ICMPv6 packet and then ship it off to WritePacket + // to have it frame it with L2 headers and send it back to the requesting + // neighbor. + CreateIcmpPacket(ip_hdr->ip6_src, ip_hdr->ip6_src, response_hdr, + absl::string_view(payload.get(), payload_size), + [this](absl::string_view packet) { + bool blocked; + std::string error; + WritePacket(packet.data(), packet.size(), &blocked, + &error); + }); + // Do not forward the neighbor solicitation through the tunnel since it's + // link-local. + return nullptr; + } + + // If this isn't a Neighbor Solicitation, remove the L2 headers and forward + // it as though it were an L3 packet. + const auto l3_packet_size = l2_packet.length() - ETH_HLEN; + auto shift_buffer = std::make_unique(l3_packet_size); + memcpy(shift_buffer.get(), l2_packet.data() + ETH_HLEN, l3_packet_size); + + return std::make_unique(shift_buffer.release(), l3_packet_size, + true); +} + } // namespace quic diff --git a/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.h b/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.h index 84c1b615..f5d57b67 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.h +++ b/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger.h @@ -5,8 +5,11 @@ #ifndef QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_PACKET_EXCHANGER_H_ #define QUICHE_QUIC_QBONE_BONNET_TUN_DEVICE_PACKET_EXCHANGER_H_ +#include + #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/qbone/platform/kernel_interface.h" +#include "gquiche/quic/qbone/platform/netlink_interface.h" #include "gquiche/quic/qbone/qbone_client_interface.h" #include "gquiche/quic/qbone/qbone_packet_exchanger.h" @@ -35,8 +38,6 @@ class TunDevicePacketExchanger : public QbonePacketExchanger { ABSL_MUST_USE_RESULT virtual int64_t PacketsWritten() const = 0; }; - // |fd| is a open file descriptor on a TUN device that's opened for both read - // and write. // |mtu| is the mtu of the TUN device. // |kernel| is not owned but should out live objects of this class. // |visitor| is not owned but should out live objects of this class. @@ -44,14 +45,13 @@ class TunDevicePacketExchanger : public QbonePacketExchanger { // the TUN device become blocked. // |stats| is notified about packet read/write statistics. It is not owned, // but should outlive objects of this class. - TunDevicePacketExchanger(int fd, - size_t mtu, - KernelInterface* kernel, + TunDevicePacketExchanger(size_t mtu, KernelInterface* kernel, + NetlinkInterface* netlink, QbonePacketExchanger::Visitor* visitor, - size_t max_pending_packets, - StatsInterface* stats); + size_t max_pending_packets, bool is_tap, + StatsInterface* stats, absl::string_view ifname); - ABSL_MUST_USE_RESULT int file_descriptor() const; + void set_file_descriptor(int fd); ABSL_MUST_USE_RESULT const StatsInterface* stats_interface() const; @@ -66,9 +66,19 @@ class TunDevicePacketExchanger : public QbonePacketExchanger { bool* blocked, std::string* error) override; + std::unique_ptr ApplyL2Headers(const QuicData& l3_packet); + + std::unique_ptr ConsumeL2Headers(const QuicData& l2_packet); + int fd_ = -1; size_t mtu_; KernelInterface* kernel_; + NetlinkInterface* netlink_; + const std::string ifname_; + + const bool is_tap_; + uint8_t tap_mac_[ETH_ALEN]{}; + bool mac_initialized_ = false; StatsInterface* stats_; }; diff --git a/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc b/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc index e56ae0cb..c6808a4f 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc +++ b/gquiche/quic/qbone/bonnet/tun_device_packet_exchanger_test.cc @@ -30,14 +30,13 @@ class MockVisitor : public QbonePacketExchanger::Visitor { class TunDevicePacketExchangerTest : public QuicTest { protected: TunDevicePacketExchangerTest() - : exchanger_(kFd, - kMtu, - &mock_kernel_, - &mock_visitor_, - kMaxPendingPackets, - &mock_stats_) {} - - ~TunDevicePacketExchangerTest() override {} + : exchanger_(kMtu, &mock_kernel_, nullptr, &mock_visitor_, + kMaxPendingPackets, false, &mock_stats_, + absl::string_view()) { + exchanger_.set_file_descriptor(kFd); + } + + ~TunDevicePacketExchangerTest() override = default; MockKernel mock_kernel_; StrictMock mock_visitor_; diff --git a/gquiche/quic/qbone/bonnet/tun_device_test.cc b/gquiche/quic/qbone/bonnet/tun_device_test.cc index 685cbd1c..b4865c2f 100644 --- a/gquiche/quic/qbone/bonnet/tun_device_test.cc +++ b/gquiche/quic/qbone/bonnet/tun_device_test.cc @@ -120,10 +120,10 @@ class TunDeviceTest : public QuicTest { int next_fd_ = 100; }; -// A TunDevice can be initialized and up +// A TunTapDevice can be initialized and up TEST_F(TunDeviceTest, BasicWorkFlow) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); - TunDevice tun_device(kDeviceName, 1500, false, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); EXPECT_TRUE(tun_device.Init()); EXPECT_GT(tun_device.GetFileDescriptor(), -1); @@ -136,17 +136,19 @@ TEST_F(TunDeviceTest, FailToOpenTunDevice) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); EXPECT_CALL(mock_kernel_, open(StrEq("/dev/net/tun"), _)) .WillOnce(Return(-1)); - TunDevice tun_device(kDeviceName, 1500, false, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); + ExpectDown(false); } TEST_F(TunDeviceTest, FailToCheckFeature) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ false); EXPECT_CALL(mock_kernel_, ioctl(_, TUNGETFEATURES, _)).WillOnce(Return(-1)); - TunDevice tun_device(kDeviceName, 1500, false, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); + ExpectDown(false); } TEST_F(TunDeviceTest, TooFewFeature) { @@ -157,15 +159,16 @@ TEST_F(TunDeviceTest, TooFewFeature) { *actual_features = IFF_TUN | IFF_ONE_QUEUE; return 0; })); - TunDevice tun_device(kDeviceName, 1500, false, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, false, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); + ExpectDown(false); } TEST_F(TunDeviceTest, FailToSetFlag) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); EXPECT_CALL(mock_kernel_, ioctl(_, TUNSETIFF, _)).WillOnce(Return(-1)); - TunDevice tun_device(kDeviceName, 1500, true, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); } @@ -173,7 +176,7 @@ TEST_F(TunDeviceTest, FailToSetFlag) { TEST_F(TunDeviceTest, FailToPersistDevice) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); EXPECT_CALL(mock_kernel_, ioctl(_, TUNSETPERSIST, _)).WillOnce(Return(-1)); - TunDevice tun_device(kDeviceName, 1500, true, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); } @@ -181,7 +184,7 @@ TEST_F(TunDeviceTest, FailToPersistDevice) { TEST_F(TunDeviceTest, FailToOpenSocket) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); EXPECT_CALL(mock_kernel_, socket(AF_INET6, _, _)).WillOnce(Return(-1)); - TunDevice tun_device(kDeviceName, 1500, true, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); } @@ -189,14 +192,14 @@ TEST_F(TunDeviceTest, FailToOpenSocket) { TEST_F(TunDeviceTest, FailToSetMtu) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); EXPECT_CALL(mock_kernel_, ioctl(_, SIOCSIFMTU, _)).WillOnce(Return(-1)); - TunDevice tun_device(kDeviceName, 1500, true, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); EXPECT_FALSE(tun_device.Init()); EXPECT_EQ(tun_device.GetFileDescriptor(), -1); } TEST_F(TunDeviceTest, FailToUp) { SetInitExpectations(/* mtu = */ 1500, /* persist = */ true); - TunDevice tun_device(kDeviceName, 1500, true, true, &mock_kernel_); + TunTapDevice tun_device(kDeviceName, 1500, true, true, false, &mock_kernel_); EXPECT_TRUE(tun_device.Init()); EXPECT_GT(tun_device.GetFileDescriptor(), -1); diff --git a/gquiche/quic/qbone/platform/icmp_packet.cc b/gquiche/quic/qbone/platform/icmp_packet.cc index 6fafcb53..38c043cc 100644 --- a/gquiche/quic/qbone/platform/icmp_packet.cc +++ b/gquiche/quic/qbone/platform/icmp_packet.cc @@ -17,7 +17,10 @@ constexpr size_t kIPv6AddressSize = sizeof(in6_addr); constexpr size_t kIPv6HeaderSize = sizeof(ip6_hdr); constexpr size_t kICMPv6HeaderSize = sizeof(icmp6_hdr); constexpr size_t kIPv6MinPacketSize = 1280; -constexpr size_t kIcmpTtl = 64; + +// Hop limit set to 255 to satisfy: +// https://datatracker.ietf.org/doc/html/rfc4861#section-11.2 +constexpr size_t kIcmpTtl = 255; constexpr size_t kICMPv6BodyMaxSize = kIPv6MinPacketSize - kIPv6HeaderSize - kICMPv6HeaderSize; diff --git a/gquiche/quic/qbone/platform/icmp_packet_test.cc b/gquiche/quic/qbone/platform/icmp_packet_test.cc index 2d3638fc..74d91055 100644 --- a/gquiche/quic/qbone/platform/icmp_packet_test.cc +++ b/gquiche/quic/qbone/platform/icmp_packet_test.cc @@ -10,7 +10,7 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace { @@ -37,8 +37,8 @@ constexpr uint8_t kReferenceICMPPacket[] = { 0x00, 0x40, // Next header is 58 0x3a, - // Hop limit is 64 - 0x40, + // Hop limit is 255 + 0xFF, // Source address of fe80:1:2:3:4::1 0xfe, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, diff --git a/gquiche/quic/qbone/platform/netlink.cc b/gquiche/quic/qbone/platform/netlink.cc index 801984d3..1454c52b 100644 --- a/gquiche/quic/qbone/platform/netlink.cc +++ b/gquiche/quic/qbone/platform/netlink.cc @@ -5,6 +5,7 @@ #include "gquiche/quic/qbone/platform/netlink.h" #include + #include #include "absl/base/attributes.h" @@ -12,8 +13,8 @@ #include "gquiche/quic/core/crypto/quic_random.h" #include "gquiche/quic/platform/api/quic_ip_address.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "platform/quic_ip_address_impl.h" #include "gquiche/quic/qbone/platform/rtnetlink_message.h" +#include "gquiche/quic/qbone/qbone_constants.h" namespace quic { @@ -571,10 +572,17 @@ bool Netlink::ChangeRoute(Netlink::Verb verb, // This is the source address to use in the IP packet should this routing rule // is used. if (preferred_source.IsInitialized()) { + auto src_str = preferred_source.ToPackedString(); message.AppendAttribute(RTA_PREFSRC, - reinterpret_cast( - preferred_source.ToPackedString().c_str()), - preferred_source.ToPackedString().size()); + reinterpret_cast(src_str.c_str()), + src_str.size()); + } + + if (verb != Verb::kRemove) { + auto gateway_str = QboneConstants::GatewayAddress()->ToPackedString(); + message.AppendAttribute(RTA_GATEWAY, + reinterpret_cast(gateway_str.c_str()), + gateway_str.size()); } if (!Send(message.BuildIoVec().get(), message.IoVecSize())) { diff --git a/gquiche/quic/qbone/platform/netlink_test.cc b/gquiche/quic/qbone/platform/netlink_test.cc index e5d33163..ef491fdd 100644 --- a/gquiche/quic/qbone/platform/netlink_test.cc +++ b/gquiche/quic/qbone/platform/netlink_test.cc @@ -564,6 +564,15 @@ TEST_F(NetlinkTest, ChangeRouteAdd) { EXPECT_EQ(preferred_ip, address); break; } + case RTA_GATEWAY: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(*QboneConstants::GatewayAddress(), address); + break; + } case RTA_OIF: { ASSERT_EQ(sizeof(int), RTA_PAYLOAD(rta)); const auto* interface_index = @@ -591,7 +600,7 @@ TEST_F(NetlinkTest, ChangeRouteAdd) { } ++num_rta; } - EXPECT_EQ(4, num_rta); + EXPECT_EQ(5, num_rta); }); EXPECT_TRUE(netlink->ChangeRoute( Netlink::Verb::kAdd, QboneConstants::kQboneRouteTableId, subnet, @@ -728,6 +737,15 @@ TEST_F(NetlinkTest, ChangeRouteReplace) { EXPECT_EQ(preferred_ip, address); break; } + case RTA_GATEWAY: { + const auto* raw_address = + reinterpret_cast(RTA_DATA(rta)); + ASSERT_EQ(sizeof(struct in6_addr), RTA_PAYLOAD(rta)); + QuicIpAddress address; + address.FromPackedString(raw_address, RTA_PAYLOAD(rta)); + EXPECT_EQ(*QboneConstants::GatewayAddress(), address); + break; + } case RTA_OIF: { ASSERT_EQ(sizeof(int), RTA_PAYLOAD(rta)); const auto* interface_index = @@ -755,7 +773,7 @@ TEST_F(NetlinkTest, ChangeRouteReplace) { } ++num_rta; } - EXPECT_EQ(4, num_rta); + EXPECT_EQ(5, num_rta); }); EXPECT_TRUE(netlink->ChangeRoute( Netlink::Verb::kReplace, QboneConstants::kQboneRouteTableId, subnet, diff --git a/gquiche/quic/qbone/platform/tcp_packet_test.cc b/gquiche/quic/qbone/platform/tcp_packet_test.cc index 44aa4e1a..4a1c5af6 100644 --- a/gquiche/quic/qbone/platform/tcp_packet_test.cc +++ b/gquiche/quic/qbone/platform/tcp_packet_test.cc @@ -10,7 +10,7 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/platform/api/quic_test.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace { diff --git a/gquiche/quic/qbone/qbone_client_session.cc b/gquiche/quic/qbone/qbone_client_session.cc index 6545ef54..aadd5896 100644 --- a/gquiche/quic/qbone/qbone_client_session.cc +++ b/gquiche/quic/qbone/qbone_client_session.cc @@ -10,6 +10,11 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/qbone/qbone_constants.h" +DEFINE_QUIC_COMMAND_LINE_FLAG( + bool, qbone_client_defer_control_stream_creation, true, + "If true, control stream in QBONE client session is created after " + "encryption established."); + namespace quic { QboneClientSession::QboneClientSession( @@ -34,12 +39,10 @@ std::unique_ptr QboneClientSession::CreateCryptoStream() { /*has_application_state = */ true); } -void QboneClientSession::Initialize() { - // Initialize must be called first, as that's what generates the crypto - // stream. - QboneSessionBase::Initialize(); - static_cast(GetMutableCryptoStream()) - ->CryptoConnect(); +void QboneClientSession::CreateControlStream() { + if (control_stream_ != nullptr) { + return; + } // Register the reserved control stream. QuicStreamId next_id = GetNextOutgoingBidirectionalStreamId(); QUICHE_DCHECK_EQ(next_id, @@ -50,6 +53,26 @@ void QboneClientSession::Initialize() { ActivateStream(std::move(control_stream)); } +void QboneClientSession::Initialize() { + // Initialize must be called first, as that's what generates the crypto + // stream. + QboneSessionBase::Initialize(); + static_cast(GetMutableCryptoStream()) + ->CryptoConnect(); + if (!GetQuicFlag(FLAGS_qbone_client_defer_control_stream_creation)) { + CreateControlStream(); + } +} + +void QboneClientSession::SetDefaultEncryptionLevel( + quic::EncryptionLevel level) { + QboneSessionBase::SetDefaultEncryptionLevel(level); + if (GetQuicFlag(FLAGS_qbone_client_defer_control_stream_creation) && + level == quic::ENCRYPTION_FORWARD_SECURE) { + CreateControlStream(); + } +} + int QboneClientSession::GetNumSentClientHellos() const { return static_cast(GetCryptoStream()) ->num_sent_client_hellos(); diff --git a/gquiche/quic/qbone/qbone_client_session.h b/gquiche/quic/qbone/qbone_client_session.h index f658a3c2..293f8166 100644 --- a/gquiche/quic/qbone/qbone_client_session.h +++ b/gquiche/quic/qbone/qbone_client_session.h @@ -33,6 +33,8 @@ class QUIC_EXPORT_PRIVATE QboneClientSession // QuicSession overrides. This will initiate the crypto stream. void Initialize() override; + // Override to create control stream at FORWARD_SECURE encryption level. + void SetDefaultEncryptionLevel(quic::EncryptionLevel level) override; // Returns the number of client hello messages that have been sent on the // crypto stream. If the handshake has completed then this is one greater @@ -65,6 +67,9 @@ class QUIC_EXPORT_PRIVATE QboneClientSession // QboneSessionBase interface implementation. std::unique_ptr CreateCryptoStream() override; + // Instantiate QboneClientControlStream. + void CreateControlStream(); + // ProofHandler interface implementation. void OnProofValid(const QuicCryptoClientConfig::CachedState& cached) override; void OnProofVerifyDetailsAvailable( @@ -82,7 +87,7 @@ class QUIC_EXPORT_PRIVATE QboneClientSession // Passed to the control stream. QboneClientControlStream::Handler* handler_; // The unowned control stream. - QboneClientControlStream* control_stream_; + QboneClientControlStream* control_stream_ = nullptr; }; } // namespace quic diff --git a/gquiche/quic/qbone/qbone_client_test.cc b/gquiche/quic/qbone/qbone_client_test.cc index ae15abd2..b78ec25e 100644 --- a/gquiche/quic/qbone/qbone_client_test.cc +++ b/gquiche/quic/qbone/qbone_client_test.cc @@ -123,12 +123,10 @@ class QuicQboneDispatcher : public QuicDispatcher { writer_(writer) {} std::unique_ptr CreateQuicSession( - QuicConnectionId id, - const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, - const quic::ParsedQuicVersion& version, - absl::string_view sni) override { + QuicConnectionId id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, + const ParsedQuicVersion& version, + const ParsedClientHello& /*parsed_chlo*/) override { QUICHE_CHECK_EQ(alpn, "qbone"); QuicConnection* connection = new QuicConnection( id, self_address, peer_address, helper(), alarm_factory(), writer(), diff --git a/gquiche/quic/qbone/qbone_constants.cc b/gquiche/quic/qbone/qbone_constants.cc index 278f24fe..18063461 100644 --- a/gquiche/quic/qbone/qbone_constants.cc +++ b/gquiche/quic/qbone/qbone_constants.cc @@ -19,7 +19,7 @@ QuicStreamId QboneConstants::GetControlStreamId(QuicTransportVersion version) { const QuicIpAddress* QboneConstants::TerminatorLocalAddress() { static auto* terminator_address = []() { - QuicIpAddress* address = new QuicIpAddress; + auto* address = new QuicIpAddress; // 0x71 0x62 0x6f 0x6e 0x65 is 'qbone' in ascii. address->FromString("fe80::71:626f:6e65"); return address; @@ -33,4 +33,13 @@ const IpRange* QboneConstants::TerminatorLocalAddressRange() { return range; } +const QuicIpAddress* QboneConstants::GatewayAddress() { + static auto* gateway_address = []() { + auto* address = new QuicIpAddress; + address->FromString("fe80::1"); + return address; + }(); + return gateway_address; +} + } // namespace quic diff --git a/gquiche/quic/qbone/qbone_constants.h b/gquiche/quic/qbone/qbone_constants.h index 82b4471d..d17e2566 100644 --- a/gquiche/quic/qbone/qbone_constants.h +++ b/gquiche/quic/qbone/qbone_constants.h @@ -25,6 +25,9 @@ struct QboneConstants { static const QuicIpAddress* TerminatorLocalAddress(); // The IPRange containing the TerminatorLocalAddress static const IpRange* TerminatorLocalAddressRange(); + // The gateway address to provide when configuring routes to the QBONE + // interface + static const QuicIpAddress* GatewayAddress(); }; } // namespace quic diff --git a/gquiche/quic/qbone/qbone_control_stream.cc b/gquiche/quic/qbone/qbone_control_stream.cc index 91398592..6ad65f83 100644 --- a/gquiche/quic/qbone/qbone_control_stream.cc +++ b/gquiche/quic/qbone/qbone_control_stream.cc @@ -4,6 +4,8 @@ #include "gquiche/quic/qbone/qbone_control_stream.h" +#include +#include #include "absl/strings/string_view.h" #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" @@ -51,10 +53,10 @@ bool QboneControlStreamBase::SendMessage(const proto2::Message& proto) { QUIC_BUG(quic_bug_11023_1) << "Failed to serialize QboneControlRequest"; return false; } - if (tmp.size() > kuint16max) { + if (tmp.size() > std::numeric_limits::max()) { QUIC_BUG(quic_bug_11023_2) << "QboneControlRequest too large: " << tmp.size() << " > " - << kuint16max; + << std::numeric_limits::max(); return false; } uint16_t size = tmp.size(); diff --git a/gquiche/quic/qbone/qbone_packet_exchanger.cc b/gquiche/quic/qbone/qbone_packet_exchanger.cc index 97839eec..3460f7d7 100644 --- a/gquiche/quic/qbone/qbone_packet_exchanger.cc +++ b/gquiche/quic/qbone/qbone_packet_exchanger.cc @@ -14,7 +14,7 @@ bool QbonePacketExchanger::ReadAndDeliverPacket( std::string error; std::unique_ptr packet = ReadPacket(&blocked, &error); if (packet == nullptr) { - if (!blocked) { + if (!blocked && visitor_) { visitor_->OnReadError(error); } return false; @@ -31,11 +31,14 @@ void QbonePacketExchanger::WritePacketToNetwork(const char* packet, if (WritePacket(packet, size, &blocked, &error)) { return; } - if (!blocked) { - visitor_->OnWriteError(error); - return; + if (blocked) { + write_blocked_ = true; + } else { + QUIC_LOG_EVERY_N_SEC(ERROR, 60) << "Packet write failed: " << error; + if (visitor_) { + visitor_->OnWriteError(error); + } } - write_blocked_ = true; } // Drop the packet on the floor if the queue if full. @@ -58,7 +61,7 @@ void QbonePacketExchanger::SetWritable() { packet_queue_.front()->length(), &blocked, &error)) { packet_queue_.pop_front(); } else { - if (!blocked) { + if (!blocked && visitor_) { visitor_->OnWriteError(error); } write_blocked_ = blocked; diff --git a/gquiche/quic/qbone/qbone_packet_exchanger_test.cc b/gquiche/quic/qbone/qbone_packet_exchanger_test.cc index 51aca646..b2f30deb 100644 --- a/gquiche/quic/qbone/qbone_packet_exchanger_test.cc +++ b/gquiche/quic/qbone/qbone_packet_exchanger_test.cc @@ -251,5 +251,21 @@ TEST(QbonePacketExchangerTest, WriteErrorsGetNotified) { ASSERT_TRUE(exchanger.packets_written().empty()); } +TEST(QbonePacketExchangerTest, NullVisitorDoesntCrash) { + FakeQbonePacketExchanger exchanger(nullptr, kMaxPendingPackets); + MockQboneClient client; + std::string packet = "data"; + + // Force read error. + std::string io_error = "I/O error"; + exchanger.SetReadError(io_error); + EXPECT_FALSE(exchanger.ReadAndDeliverPacket(&client)); + + // Force write error + exchanger.ForceWriteFailure(false, io_error); + exchanger.WritePacketToNetwork(packet.data(), packet.length()); + EXPECT_TRUE(exchanger.packets_written().empty()); +} + } // namespace } // namespace quic diff --git a/gquiche/quic/qbone/qbone_packet_processor.cc b/gquiche/quic/qbone/qbone_packet_processor.cc index c2fcf45f..1aea1086 100644 --- a/gquiche/quic/qbone/qbone_packet_processor.cc +++ b/gquiche/quic/qbone/qbone_packet_processor.cc @@ -103,6 +103,10 @@ void QbonePacketProcessor::ProcessPacket(std::string* packet, SendTcpReset(*packet, direction); stats_->OnPacketDroppedWithTcpReset(direction); break; + case ProcessingResult::TCP_RESET: + SendTcpReset(*packet, direction); + stats_->OnPacketDroppedWithTcpReset(direction); + break; } } diff --git a/gquiche/quic/qbone/qbone_packet_processor.h b/gquiche/quic/qbone/qbone_packet_processor.h index e96e5757..917a9df9 100644 --- a/gquiche/quic/qbone/qbone_packet_processor.h +++ b/gquiche/quic/qbone/qbone_packet_processor.h @@ -46,6 +46,8 @@ class QbonePacketProcessor { // RST requires information from the current connection state to be // well-formed. ICMP_AND_TCP_RESET = 4, + // Send a TCP RST. + TCP_RESET = 5, }; class OutputInterface { diff --git a/gquiche/quic/qbone/qbone_server_session.cc b/gquiche/quic/qbone/qbone_server_session.cc index e9389cf3..f91ee45d 100644 --- a/gquiche/quic/qbone/qbone_server_session.cc +++ b/gquiche/quic/qbone/qbone_server_session.cc @@ -4,6 +4,7 @@ #include "gquiche/quic/qbone/qbone_server_session.h" +#include #include #include "absl/strings/string_view.h" @@ -12,6 +13,11 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/qbone/qbone_constants.h" +DEFINE_QUIC_COMMAND_LINE_FLAG( + bool, qbone_server_defer_control_stream_creation, true, + "If true, control stream in QBONE server session is created after " + "encryption established."); + namespace quic { bool QboneCryptoServerStreamHelper::CanAcceptClientHello( @@ -55,8 +61,10 @@ std::unique_ptr QboneServerSession::CreateCryptoStream() { &stream_helper_); } -void QboneServerSession::Initialize() { - QboneSessionBase::Initialize(); +void QboneServerSession::CreateControlStream() { + if (control_stream_ != nullptr) { + return; + } // Register the reserved control stream. auto control_stream = std::make_unique(this, handler_); @@ -64,6 +72,22 @@ void QboneServerSession::Initialize() { ActivateStream(std::move(control_stream)); } +void QboneServerSession::Initialize() { + QboneSessionBase::Initialize(); + if (!GetQuicFlag(FLAGS_qbone_server_defer_control_stream_creation)) { + CreateControlStream(); + } +} + +void QboneServerSession::SetDefaultEncryptionLevel( + quic::EncryptionLevel level) { + QboneSessionBase::SetDefaultEncryptionLevel(level); + if (GetQuicFlag(FLAGS_qbone_server_defer_control_stream_creation) && + level == quic::ENCRYPTION_FORWARD_SECURE) { + CreateControlStream(); + } +} + bool QboneServerSession::SendClientRequest(const QboneClientRequest& request) { if (!control_stream_) { QUIC_BUG(quic_bug_11026_1) diff --git a/gquiche/quic/qbone/qbone_server_session.h b/gquiche/quic/qbone/qbone_server_session.h index b6739a46..138d6b37 100644 --- a/gquiche/quic/qbone/qbone_server_session.h +++ b/gquiche/quic/qbone/qbone_server_session.h @@ -50,6 +50,8 @@ class QUIC_EXPORT_PRIVATE QboneServerSession ~QboneServerSession() override; void Initialize() override; + // Override to create control stream at FORWARD_SECURE encryption level. + void SetDefaultEncryptionLevel(quic::EncryptionLevel level) override; virtual bool SendClientRequest(const QboneClientRequest& request); @@ -73,6 +75,10 @@ class QUIC_EXPORT_PRIVATE QboneServerSession protected: // QboneSessionBase interface implementation. std::unique_ptr CreateCryptoStream() override; + + // Instantiate QboneServerControlStream. + void CreateControlStream(); + // The packet processor. QbonePacketProcessor processor_; @@ -86,7 +92,7 @@ class QUIC_EXPORT_PRIVATE QboneServerSession // Passed to the control stream. QboneServerControlStream::Handler* handler_; // The unowned control stream. - QboneServerControlStream* control_stream_; + QboneServerControlStream* control_stream_ = nullptr; }; } // namespace quic diff --git a/gquiche/quic/qbone/qbone_session_base.cc b/gquiche/quic/qbone/qbone_session_base.cc index efbf953c..0a33f4f4 100644 --- a/gquiche/quic/qbone/qbone_session_base.cc +++ b/gquiche/quic/qbone/qbone_session_base.cc @@ -140,11 +140,9 @@ void QboneSessionBase::SendPacketToPeer(absl::string_view packet) { } if (send_packets_as_messages_) { - QuicUniqueBufferPtr buffer = MakeUniqueBuffer( - connection()->helper()->GetStreamSendBufferAllocator(), packet.size()); - memcpy(buffer.get(), packet.data(), packet.size()); - QuicMemSlice slice(std::move(buffer), packet.size()); - switch (SendMessage(QuicMemSliceSpan(&slice), /*flush=*/true).status) { + QuicMemSlice slice(QuicBuffer::Copy( + connection()->helper()->GetStreamSendBufferAllocator(), packet)); + switch (SendMessage(absl::MakeSpan(&slice, 1), /*flush=*/true).status) { case MESSAGE_STATUS_SUCCESS: break; case MESSAGE_STATUS_TOO_LARGE: { diff --git a/gquiche/quic/qbone/qbone_session_test.cc b/gquiche/quic/qbone/qbone_session_test.cc index aae3190d..d74f77bd 100644 --- a/gquiche/quic/qbone/qbone_session_test.cc +++ b/gquiche/quic/qbone/qbone_session_test.cc @@ -23,7 +23,6 @@ #include "gquiche/quic/test_tools/quic_connection_peer.h" #include "gquiche/quic/test_tools/quic_session_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { @@ -80,9 +79,9 @@ class IndirectionProofSource : public ProofSource { absl::string_view chlo_hash, std::unique_ptr callback) override { if (!proof_source_) { - QuicReferenceCountedPointer chain = - GetCertChain(server_address, client_address, hostname); QuicCryptoProof proof; + QuicReferenceCountedPointer chain = GetCertChain( + server_address, client_address, hostname, &proof.cert_matched_sni); callback->Run(/*ok=*/false, chain, proof, /*details=*/nullptr); return; } @@ -93,13 +92,13 @@ class IndirectionProofSource : public ProofSource { QuicReferenceCountedPointer GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override { + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override { if (!proof_source_) { return QuicReferenceCountedPointer(); } - return proof_source_->GetCertChain(server_address, client_address, - hostname); + return proof_source_->GetCertChain(server_address, client_address, hostname, + cert_matched_sni); } void ComputeTlsSignature( @@ -118,6 +117,14 @@ class IndirectionProofSource : public ProofSource { std::move(callback)); } + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override { + if (!proof_source_) { + return {}; + } + return proof_source_->SupportedTlsSignatureAlgorithms(); + } + TicketCrypter* GetTicketCrypter() override { return nullptr; } private: diff --git a/gquiche/quic/qbone/qbone_stream_test.cc b/gquiche/quic/qbone/qbone_stream_test.cc index 95bb2d07..34565dee 100644 --- a/gquiche/quic/qbone/qbone_stream_test.cc +++ b/gquiche/quic/qbone/qbone_stream_test.cc @@ -31,21 +31,16 @@ using ::testing::StrictMock; class MockQuicSession : public QboneSessionBase { public: MockQuicSession(QuicConnection* connection, const QuicConfig& config) - : QboneSessionBase(connection, - nullptr /*visitor*/, - config, - CurrentSupportedVersions(), - nullptr /*writer*/) {} + : QboneSessionBase(connection, nullptr /*visitor*/, config, + CurrentSupportedVersions(), nullptr /*writer*/) {} ~MockQuicSession() override {} // Writes outgoing data from QuicStream to a string. - QuicConsumedData WritevData(QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, + QuicConsumedData WritevData(QuicStreamId id, size_t write_length, + QuicStreamOffset offset, StreamSendingState state, TransmissionType type, - absl::optional level) override { + EncryptionLevel level) override { if (!writable_) { return QuicConsumedData(0, false); } @@ -58,16 +53,12 @@ class MockQuicSession : public QboneSessionBase { } // Called by QuicStream when they want to close stream. - MOCK_METHOD(void, - MaybeSendRstStreamFrame, - (QuicStreamId stream_id, - QuicRstStreamErrorCode error, + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, QuicStreamOffset bytes_written), (override)); - MOCK_METHOD(void, - MaybeSendStopSendingFrame, - (QuicStreamId stream_id, QuicRstStreamErrorCode error), - (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); // Sets whether data is written to buffer, or else if this is write blocked. void set_writable(bool writable) { writable_ = writable; } @@ -106,8 +97,7 @@ class DummyPacketWriter : public QuicPacketWriter { DummyPacketWriter() {} // QuicPacketWriter overrides. - WriteResult WritePacket(const char* buffer, - size_t buf_len, + WriteResult WritePacket(const char* buffer, size_t buf_len, const QuicIpAddress& self_address, const QuicSocketAddress& peer_address, PerPacketOptions* options) override { @@ -183,8 +173,7 @@ class QboneReadOnlyStreamTest : public ::testing::Test, SimpleBufferAllocator buffer_allocator_; MockClock clock_; const QuicStreamId kStreamId = QuicUtils::GetFirstUnidirectionalStreamId( - CurrentSupportedVersions()[0].transport_version, - Perspective::IS_CLIENT); + CurrentSupportedVersions()[0].transport_version, Perspective::IS_CLIENT); }; // Read an entire string. @@ -242,9 +231,13 @@ TEST_F(QboneReadOnlyStreamTest, ReadBufferedTooLarge) { std::string packet = "0123456789"; int iterations = (QboneConstants::kMaxQbonePacketBytes / packet.size()) + 2; EXPECT_CALL(*session_, MaybeSendStopSendingFrame( - kStreamId, QUIC_BAD_APPLICATION_PAYLOAD)); - EXPECT_CALL(*session_, MaybeSendRstStreamFrame( - kStreamId, QUIC_BAD_APPLICATION_PAYLOAD, _)); + kStreamId, QuicResetStreamError::FromInternal( + QUIC_BAD_APPLICATION_PAYLOAD))); + EXPECT_CALL( + *session_, + MaybeSendRstStreamFrame( + kStreamId, + QuicResetStreamError::FromInternal(QUIC_BAD_APPLICATION_PAYLOAD), _)); for (int i = 0; i < iterations; ++i) { QuicStreamFrame frame(kStreamId, i == (iterations - 1), i * packet.size(), packet); diff --git a/gquiche/quic/quic_transport/quic_transport_client_session.cc b/gquiche/quic/quic_transport/quic_transport_client_session.cc index 1230156e..8706c162 100644 --- a/gquiche/quic/quic_transport/quic_transport_client_session.cc +++ b/gquiche/quic/quic_transport/quic_transport_client_session.cc @@ -242,7 +242,7 @@ void QuicTransportClientSession::SendClientIndication() { } ready_ = true; - visitor_->OnSessionReady(); + visitor_->OnSessionReady(spdy::SpdyHeaderBlock()); } void QuicTransportClientSession::OnMessageReceived(absl::string_view message) { diff --git a/gquiche/quic/quic_transport/quic_transport_client_session.h b/gquiche/quic/quic_transport/quic_transport_client_session.h index 474fd1ee..e15c740d 100644 --- a/gquiche/quic/quic_transport/quic_transport_client_session.h +++ b/gquiche/quic/quic_transport/quic_transport_client_session.h @@ -17,6 +17,7 @@ #include "gquiche/quic/core/quic_crypto_client_stream.h" #include "gquiche/quic/core/quic_crypto_stream.h" #include "gquiche/quic/core/quic_datagram_queue.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_server_id.h" #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_stream.h" @@ -113,6 +114,17 @@ class QUIC_EXPORT_PRIVATE QuicTransportClientSession void OnProofVerifyDetailsAvailable( const ProofVerifyDetails& verify_details) override; + void CloseSession(WebTransportSessionError /*error_code*/, + absl::string_view error_message) override { + connection()->CloseConnection( + QUIC_NO_ERROR, std::string(error_message), + ConnectionCloseBehavior::SEND_CONNECTION_CLOSE_PACKET); + } + + QuicByteCount GetMaxDatagramSize() const override { + return GetGuaranteedLargestMessagePayload(); + } + protected: class QUIC_EXPORT_PRIVATE ClientIndication : public QuicStream { public: @@ -152,8 +164,10 @@ class QUIC_EXPORT_PRIVATE QuicTransportClientSession // has not accepted to a smaller number, by checking the size of // |incoming_bidirectional_streams_| and |incoming_unidirectional_streams_| // before sending MAX_STREAMS. - QuicCircularDeque incoming_bidirectional_streams_; - QuicCircularDeque incoming_unidirectional_streams_; + quiche::QuicheCircularDeque + incoming_bidirectional_streams_; + quiche::QuicheCircularDeque + incoming_unidirectional_streams_; }; } // namespace quic diff --git a/gquiche/quic/quic_transport/quic_transport_client_session_test.cc b/gquiche/quic/quic_transport/quic_transport_client_session_test.cc index edc9b655..27e2d8a5 100644 --- a/gquiche/quic/quic_transport/quic_transport_client_session_test.cc +++ b/gquiche/quic/quic_transport/quic_transport_client_session_test.cc @@ -104,7 +104,7 @@ TEST_F(QuicTransportClientSessionTest, SuccessfulConnection) { "\0\x01" // length "/"; // value - EXPECT_CALL(visitor_, OnSessionReady()); + EXPECT_CALL(visitor_, OnSessionReady(_)); Connect(); EXPECT_TRUE(session_->IsSessionReady()); diff --git a/gquiche/quic/quic_transport/quic_transport_integration_test.cc b/gquiche/quic/quic_transport/quic_transport_integration_test.cc index 2f261645..b4e169f6 100644 --- a/gquiche/quic/quic_transport/quic_transport_integration_test.cc +++ b/gquiche/quic/quic_transport/quic_transport_integration_test.cc @@ -212,7 +212,7 @@ class QuicTransportIntegrationTest : public QuicTest { TEST_F(QuicTransportIntegrationTest, SuccessfulHandshake) { CreateDefaultEndpoints("/discard"); WireUpEndpoints(); - EXPECT_CALL(*client_->visitor(), OnSessionReady()); + EXPECT_CALL(*client_->visitor(), OnSessionReady(_)); RunHandshake(); EXPECT_TRUE(client_->session()->IsSessionReady()); EXPECT_TRUE(server_->session()->IsSessionReady()); diff --git a/gquiche/quic/quic_transport/quic_transport_server_session_test.cc b/gquiche/quic/quic_transport/quic_transport_server_session_test.cc index 84bfcadd..ad9798e8 100644 --- a/gquiche/quic/quic_transport/quic_transport_server_session_test.cc +++ b/gquiche/quic/quic_transport/quic_transport_server_session_test.cc @@ -21,7 +21,7 @@ #include "gquiche/quic/test_tools/crypto_test_utils.h" #include "gquiche/quic/test_tools/quic_test_utils.h" #include "gquiche/quic/test_tools/quic_transport_test_tools.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace test { diff --git a/gquiche/quic/quic_transport/quic_transport_stream.h b/gquiche/quic/quic_transport/quic_transport_stream.h index c79a410f..b9c39c4e 100644 --- a/gquiche/quic/quic_transport/quic_transport_stream.h +++ b/gquiche/quic/quic_transport/quic_transport_stream.h @@ -10,11 +10,11 @@ #include "absl/base/attributes.h" #include "absl/strings/string_view.h" +#include "gquiche/quic/core/http/web_transport_stream_adapter.h" #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/web_transport_interface.h" -#include "gquiche/quic/core/web_transport_stream_adapter.h" #include "gquiche/quic/quic_transport/quic_transport_session_interface.h" namespace quic { @@ -50,12 +50,15 @@ class QUIC_EXPORT_PRIVATE QuicTransportStream : public QuicStream, QuicStreamId GetStreamId() const override { return id(); } - void ResetWithUserCode(QuicRstStreamErrorCode error) override { - adapter_.ResetWithUserCode(error); + void ResetWithUserCode(WebTransportStreamError /*error*/) override { + adapter_.ResetWithUserCode(0); } void ResetDueToInternalError() override { adapter_.ResetDueToInternalError(); } + void SendStopSending(WebTransportStreamError /*error*/) override { + adapter_.SendStopSending(0); + } void MaybeResetDueToStreamObjectGone() override { adapter_.MaybeResetDueToStreamObjectGone(); } diff --git a/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.cc b/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.cc index 857d920e..5765b872 100644 --- a/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.cc +++ b/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.cc @@ -7,24 +7,23 @@ #include #include +#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "openssl/sha.h" #include "gquiche/quic/core/crypto/certificate_view.h" #include "gquiche/quic/core/quic_time.h" #include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace { constexpr size_t kFingerprintLength = SHA256_DIGEST_LENGTH * 3 - 1; -constexpr std::array kHexDigits = {'0', '1', '2', '3', '4', '5', - '6', '7', '8', '9', 'a', 'b', - 'c', 'd', 'e', 'f'}; - // Assumes that the character is normalized to lowercase beforehand. bool IsNormalizedHexDigit(char c) { return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'); @@ -38,31 +37,7 @@ void NormalizeFingerprint(CertificateFingerprint& fingerprint) { } // namespace constexpr char CertificateFingerprint::kSha256[]; - -std::string ComputeSha256Fingerprint(absl::string_view input) { - std::vector raw_hash; - raw_hash.resize(SHA256_DIGEST_LENGTH); - SHA256(reinterpret_cast(input.data()), input.size(), - raw_hash.data()); - - std::string output; - output.resize(kFingerprintLength); - for (size_t i = 0; i < output.size(); i++) { - uint8_t hash_byte = raw_hash[i / 3]; - switch (i % 3) { - case 0: - output[i] = kHexDigits[hash_byte >> 4]; - break; - case 1: - output[i] = kHexDigits[hash_byte & 0xf]; - break; - case 2: - output[i] = ':'; - break; - } - } - return output; -} +constexpr char WebTransportHash::kSha256[]; ProofVerifyDetails* WebTransportFingerprintProofVerifier::Details::Clone() const { @@ -70,8 +45,7 @@ ProofVerifyDetails* WebTransportFingerprintProofVerifier::Details::Clone() } WebTransportFingerprintProofVerifier::WebTransportFingerprintProofVerifier( - const QuicClock* clock, - int max_validity_days) + const QuicClock* clock, int max_validity_days) : clock_(clock), max_validity_days_(max_validity_days), // Add an extra second to max validity to accomodate various edge cases. @@ -105,22 +79,34 @@ bool WebTransportFingerprintProofVerifier::AddFingerprint( } } - fingerprints_.push_back(fingerprint); + std::string normalized = + absl::StrReplaceAll(fingerprint.fingerprint, {{":", ""}}); + hashes_.push_back(WebTransportHash{fingerprint.algorithm, + absl::HexStringToBytes(normalized)}); + return true; +} + +bool WebTransportFingerprintProofVerifier::AddFingerprint( + WebTransportHash hash) { + if (hash.algorithm != CertificateFingerprint::kSha256) { + QUIC_DLOG(WARNING) << "Algorithms other than SHA-256 are not supported"; + return false; + } + if (hash.value.size() != SHA256_DIGEST_LENGTH) { + QUIC_DLOG(WARNING) << "Invalid fingerprint length"; + return false; + } + hashes_.push_back(std::move(hash)); return true; } QuicAsyncStatus WebTransportFingerprintProofVerifier::VerifyProof( - const std::string& /*hostname*/, - const uint16_t /*port*/, + const std::string& /*hostname*/, const uint16_t /*port*/, const std::string& /*server_config*/, - QuicTransportVersion /*transport_version*/, - absl::string_view /*chlo_hash*/, - const std::vector& /*certs*/, - const std::string& /*cert_sct*/, - const std::string& /*signature*/, - const ProofVerifyContext* /*context*/, - std::string* error_details, - std::unique_ptr* details, + QuicTransportVersion /*transport_version*/, absl::string_view /*chlo_hash*/, + const std::vector& /*certs*/, const std::string& /*cert_sct*/, + const std::string& /*signature*/, const ProofVerifyContext* /*context*/, + std::string* error_details, std::unique_ptr* details, std::unique_ptr /*callback*/) { *error_details = "QUIC crypto certificate verification is not supported in " @@ -131,14 +117,10 @@ QuicAsyncStatus WebTransportFingerprintProofVerifier::VerifyProof( } QuicAsyncStatus WebTransportFingerprintProofVerifier::VerifyCertChain( - const std::string& /*hostname*/, - const uint16_t /*port*/, - const std::vector& certs, - const std::string& /*ocsp_response*/, - const std::string& /*cert_sct*/, - const ProofVerifyContext* /*context*/, - std::string* error_details, - std::unique_ptr* details, + const std::string& /*hostname*/, const uint16_t /*port*/, + const std::vector& certs, const std::string& /*ocsp_response*/, + const std::string& /*cert_sct*/, const ProofVerifyContext* /*context*/, + std::string* error_details, std::unique_ptr* details, uint8_t* /*out_alert*/, std::unique_ptr /*callback*/) { if (certs.empty()) { @@ -187,14 +169,14 @@ WebTransportFingerprintProofVerifier::CreateDefaultContext() { bool WebTransportFingerprintProofVerifier::HasKnownFingerprint( absl::string_view der_certificate) { - // https://wicg.github.io/web-transport/#verify-a-certificate-fingerprint - const std::string fingerprint = ComputeSha256Fingerprint(der_certificate); - for (const CertificateFingerprint& reference : fingerprints_) { - if (reference.algorithm != CertificateFingerprint::kSha256) { + // https://w3c.github.io/webtransport/#verify-a-certificate-hash + const std::string hash = RawSha256(der_certificate); + for (const WebTransportHash& reference : hashes_) { + if (reference.algorithm != WebTransportHash::kSha256) { QUIC_BUG(quic_bug_10879_2) << "Unexpected non-SHA-256 hash"; continue; } - if (fingerprint == reference.fingerprint) { + if (hash == reference.value) { return true; } } diff --git a/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.h b/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.h index 366df7c1..dfe7d530 100644 --- a/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.h +++ b/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier.h @@ -11,11 +11,14 @@ #include "gquiche/quic/core/crypto/certificate_view.h" #include "gquiche/quic/core/crypto/proof_verifier.h" #include "gquiche/quic/core/quic_clock.h" +#include "gquiche/quic/platform/api/quic_export.h" namespace quic { // Represents a fingerprint of an X.509 certificate in a format based on // https://w3c.github.io/webrtc-pc/#dom-rtcdtlsfingerprint. +// TODO(vasilvv): remove this once all consumers of this API use +// WebTransportHash. struct QUIC_EXPORT_PRIVATE CertificateFingerprint { static constexpr char kSha256[] = "sha-256"; @@ -27,10 +30,17 @@ struct QUIC_EXPORT_PRIVATE CertificateFingerprint { std::string fingerprint; }; -// Computes a SHA-256 fingerprint of the specified input formatted in the same -// format as CertificateFingerprint::fingerprint would contain. -QUIC_EXPORT_PRIVATE std::string ComputeSha256Fingerprint( - absl::string_view input); +// Represents a fingerprint of an X.509 certificate in a format based on +// https://w3c.github.io/webtransport/#dictdef-webtransporthash. +struct QUIC_EXPORT_PRIVATE WebTransportHash { + static constexpr char kSha256[] = "sha-256"; + + // An algorithm described by one of the names in + // https://www.iana.org/assignments/hash-function-text-names/hash-function-text-names.xhtml + std::string algorithm; + // Raw bytes of the hash. + std::string value; +}; // WebTransportFingerprintProofVerifier verifies the server leaf certificate // against a supplied list of certificate fingerprints following the procedure @@ -76,6 +86,7 @@ class QUIC_EXPORT_PRIVATE WebTransportFingerprintProofVerifier // case-insensitive and are validated internally; the function returns true if // the validation passes. bool AddFingerprint(CertificateFingerprint fingerprint); + bool AddFingerprint(WebTransportHash hash); // ProofVerifier implementation. QuicAsyncStatus VerifyProof( @@ -112,7 +123,7 @@ class QUIC_EXPORT_PRIVATE WebTransportFingerprintProofVerifier const QuicClock* clock_; // Unowned. const int max_validity_days_; const QuicTime::Delta max_validity_; - std::vector fingerprints_; + std::vector hashes_; }; } // namespace quic diff --git a/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier_test.cc b/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier_test.cc index 4e14ecae..d195f1ed 100644 --- a/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier_test.cc +++ b/gquiche/quic/quic_transport/web_transport_fingerprint_proof_verifier_test.cc @@ -6,8 +6,10 @@ #include +#include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/quic_types.h" +#include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/mock_clock.h" #include "gquiche/quic/test_tools/test_certificates.h" @@ -56,9 +58,8 @@ class WebTransportFingerprintProofVerifierTest : public QuicTest { } void AddTestCertificate() { - EXPECT_TRUE(verifier_->AddFingerprint( - CertificateFingerprint{CertificateFingerprint::kSha256, - ComputeSha256Fingerprint(kTestCertificate)})); + EXPECT_TRUE(verifier_->AddFingerprint(WebTransportHash{ + WebTransportHash::kSha256, RawSha256(kTestCertificate)})); } MockClock clock_; @@ -67,9 +68,9 @@ class WebTransportFingerprintProofVerifierTest : public QuicTest { TEST_F(WebTransportFingerprintProofVerifierTest, Sha256Fingerprint) { // Computed using `openssl x509 -fingerprint -sha256`. - EXPECT_EQ(ComputeSha256Fingerprint(kTestCertificate), - "f2:e5:46:5e:2b:f7:ec:d6:f6:30:66:a5:a3:75:11:73:4a:a0:eb:7c:47:01:" - "0e:86:d6:75:8e:d4:f4:fa:1b:0f"); + EXPECT_EQ(absl::BytesToHexString(RawSha256(kTestCertificate)), + "f2e5465e2bf7ecd6f63066a5a37511734aa0eb7c4701" + "0e86d6758ed4f4fa1b0f"); } TEST_F(WebTransportFingerprintProofVerifierTest, SimpleFingerprint) { @@ -138,9 +139,8 @@ TEST_F(WebTransportFingerprintProofVerifierTest, MaxValidity) { TEST_F(WebTransportFingerprintProofVerifierTest, InvalidCertificate) { constexpr absl::string_view kInvalidCertificate = "Hello, world!"; - ASSERT_TRUE(verifier_->AddFingerprint( - {CertificateFingerprint::kSha256, - ComputeSha256Fingerprint(kInvalidCertificate)})); + ASSERT_TRUE(verifier_->AddFingerprint(WebTransportHash{ + WebTransportHash::kSha256, RawSha256(kInvalidCertificate)})); VerifyResult result = Verify(kInvalidCertificate); EXPECT_EQ(result.status, QUIC_FAILURE); @@ -153,30 +153,29 @@ TEST_F(WebTransportFingerprintProofVerifierTest, AddCertificate) { // Accept all-uppercase fingerprints. verifier_ = std::make_unique( &clock_, /*max_validity_days=*/365); - EXPECT_TRUE(verifier_->AddFingerprint( - {CertificateFingerprint::kSha256, - "F2:E5:46:5E:2B:F7:EC:D6:F6:30:66:A5:A3:75:11:73:4A:A0:EB:" - "7C:47:01:0E:86:D6:75:8E:D4:F4:FA:1B:0F"})); + EXPECT_TRUE(verifier_->AddFingerprint(CertificateFingerprint{ + CertificateFingerprint::kSha256, + "F2:E5:46:5E:2B:F7:EC:D6:F6:30:66:A5:A3:75:11:73:4A:A0:EB:" + "7C:47:01:0E:86:D6:75:8E:D4:F4:FA:1B:0F"})); EXPECT_EQ(Verify(kTestCertificate).detailed_status, WebTransportFingerprintProofVerifier::Status::kValidCertificate); // Reject unknown hash algorithms. - EXPECT_FALSE(verifier_->AddFingerprint( - {"sha-1", - "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"})); + EXPECT_FALSE(verifier_->AddFingerprint(CertificateFingerprint{ + "sha-1", "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"})); // Reject invalid length. EXPECT_FALSE(verifier_->AddFingerprint( - {CertificateFingerprint::kSha256, "00:00:00:00"})); + CertificateFingerprint{CertificateFingerprint::kSha256, "00:00:00:00"})); // Reject missing colons. - EXPECT_FALSE(verifier_->AddFingerprint( - {CertificateFingerprint::kSha256, - "00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00." - "00.00.00.00.00.00.00.00.00.00.00.00.00"})); + EXPECT_FALSE(verifier_->AddFingerprint(CertificateFingerprint{ + CertificateFingerprint::kSha256, + "00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00.00." + "00.00.00.00.00.00.00.00.00.00.00.00.00"})); // Reject non-hex symbols. - EXPECT_FALSE(verifier_->AddFingerprint( - {CertificateFingerprint::kSha256, - "zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:" - "zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz"})); + EXPECT_FALSE(verifier_->AddFingerprint(CertificateFingerprint{ + CertificateFingerprint::kSha256, + "zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:" + "zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz:zz"})); } } // namespace diff --git a/gquiche/quic/test_tools/crypto_test_utils.cc b/gquiche/quic/test_tools/crypto_test_utils.cc index 91c38b59..d512948c 100644 --- a/gquiche/quic/test_tools/crypto_test_utils.cc +++ b/gquiche/quic/test_tools/crypto_test_utils.cc @@ -40,7 +40,6 @@ #include "gquiche/quic/test_tools/quic_stream_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" #include "gquiche/quic/test_tools/simple_quic_framer.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" #include "gquiche/common/test_tools/quiche_test_utils.h" namespace quic { @@ -229,7 +228,7 @@ int HandshakeWithFakeServer(QuicConfig* server_quic_config, MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, PacketSavingConnection* client_conn, - QuicCryptoClientStream* client, + QuicCryptoClientStreamBase* client, std::string alpn) { auto* server_conn = new testing::NiceMock( helper, alarm_factory, Perspective::IS_SERVER, @@ -594,7 +593,7 @@ void CompareCrypters(const QuicEncrypter* encrypter, } // namespace -void CompareClientAndServerKeys(QuicCryptoClientStream* client, +void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client, QuicCryptoServerStreamBase* server) { QuicFramer* client_framer = QuicConnectionPeer::GetFramer( QuicStreamPeer::session(client)->connection()); @@ -635,22 +634,6 @@ void CompareClientAndServerKeys(QuicCryptoClientStream* client, "subkey secret", client_subkey_secret.data(), client_subkey_secret.length(), server_subkey_secret.data(), server_subkey_secret.length()); - - const char kSampleLabel[] = "label"; - const char kSampleContext[] = "context"; - const size_t kSampleOutputLength = 32; - std::string client_key_extraction; - std::string server_key_extraction; - EXPECT_TRUE(client->ExportKeyingMaterial(kSampleLabel, kSampleContext, - kSampleOutputLength, - &client_key_extraction)); - EXPECT_TRUE(server->ExportKeyingMaterial(kSampleLabel, kSampleContext, - kSampleOutputLength, - &server_key_extraction)); - quiche::test::CompareCharArraysWithHexError( - "sample key extraction", client_key_extraction.data(), - client_key_extraction.length(), server_key_extraction.data(), - server_key_extraction.length()); } QuicTag ParseTag(const char* tagstr) { @@ -744,6 +727,12 @@ void MovePackets(PacketSavingConnection* source_conn, size_t index = *inout_packet_index; for (; index < source_conn->encrypted_packets_.size(); index++) { + if (!dest_conn->connected()) { + QUIC_LOG(INFO) + << "Destination connection disconnected. Skipping packet at index " + << index; + continue; + } // In order to properly test the code we need to perform encryption and // decryption so that the crypters latch when expected. The crypters are in // |dest_conn|, but we don't want to try and use them there. Instead we swap diff --git a/gquiche/quic/test_tools/crypto_test_utils.h b/gquiche/quic/test_tools/crypto_test_utils.h index 16b8093c..535b9c30 100644 --- a/gquiche/quic/test_tools/crypto_test_utils.h +++ b/gquiche/quic/test_tools/crypto_test_utils.h @@ -78,7 +78,7 @@ int HandshakeWithFakeServer(QuicConfig* server_quic_config, MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, PacketSavingConnection* client_conn, - QuicCryptoClientStream* client, + QuicCryptoClientStreamBase* client, std::string alpn); // returns: the number of client hellos that the client sent. @@ -195,7 +195,7 @@ void GenerateFullCHLO( QuicCompressedCertsCache* compressed_certs_cache, CryptoHandshakeMessage* out); -void CompareClientAndServerKeys(QuicCryptoClientStream* client, +void CompareClientAndServerKeys(QuicCryptoClientStreamBase* client, QuicCryptoServerStreamBase* server); // Return a CHLO nonce in hexadecimal. diff --git a/gquiche/quic/test_tools/crypto_test_utils_test.cc b/gquiche/quic/test_tools/crypto_test_utils_test.cc index d020b5b1..66493ddb 100644 --- a/gquiche/quic/test_tools/crypto_test_utils_test.cc +++ b/gquiche/quic/test_tools/crypto_test_utils_test.cc @@ -12,7 +12,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/mock_clock.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" namespace quic { namespace test { diff --git a/gquiche/quic/test_tools/failing_proof_source.cc b/gquiche/quic/test_tools/failing_proof_source.cc index 072f6cba..4353c9f3 100644 --- a/gquiche/quic/test_tools/failing_proof_source.cc +++ b/gquiche/quic/test_tools/failing_proof_source.cc @@ -22,7 +22,9 @@ void FailingProofSource::GetProof(const QuicSocketAddress& /*server_address*/, QuicReferenceCountedPointer FailingProofSource::GetCertChain(const QuicSocketAddress& /*server_address*/, const QuicSocketAddress& /*client_address*/, - const std::string& /*hostname*/) { + const std::string& /*hostname*/, + bool* cert_matched_sni) { + *cert_matched_sni = false; return QuicReferenceCountedPointer(); } diff --git a/gquiche/quic/test_tools/failing_proof_source.h b/gquiche/quic/test_tools/failing_proof_source.h index d3e333f2..6908785d 100644 --- a/gquiche/quic/test_tools/failing_proof_source.h +++ b/gquiche/quic/test_tools/failing_proof_source.h @@ -23,8 +23,8 @@ class FailingProofSource : public ProofSource { QuicReferenceCountedPointer GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; void ComputeTlsSignature( const QuicSocketAddress& server_address, @@ -34,6 +34,11 @@ class FailingProofSource : public ProofSource { absl::string_view in, std::unique_ptr callback) override; + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override { + return {}; + } + TicketCrypter* GetTicketCrypter() override { return nullptr; } }; diff --git a/gquiche/quic/test_tools/fake_proof_source.cc b/gquiche/quic/test_tools/fake_proof_source.cc index 4222327a..3ff8917c 100644 --- a/gquiche/quic/test_tools/fake_proof_source.cc +++ b/gquiche/quic/test_tools/fake_proof_source.cc @@ -96,9 +96,10 @@ void FakeProofSource::GetProof( QuicReferenceCountedPointer FakeProofSource::GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) { - return delegate_->GetCertChain(server_address, client_address, hostname); + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) { + return delegate_->GetCertChain(server_address, client_address, hostname, + cert_matched_sni); } void FakeProofSource::ComputeTlsSignature( @@ -123,6 +124,11 @@ void FakeProofSource::ComputeTlsSignature( std::move(callback), delegate_.get())); } +absl::InlinedVector +FakeProofSource::SupportedTlsSignatureAlgorithms() const { + return delegate_->SupportedTlsSignatureAlgorithms(); +} + ProofSource::TicketCrypter* FakeProofSource::GetTicketCrypter() { if (ticket_crypter_) { return ticket_crypter_.get(); diff --git a/gquiche/quic/test_tools/fake_proof_source.h b/gquiche/quic/test_tools/fake_proof_source.h index 4a166a6b..731dcde2 100644 --- a/gquiche/quic/test_tools/fake_proof_source.h +++ b/gquiche/quic/test_tools/fake_proof_source.h @@ -43,8 +43,8 @@ class FakeProofSource : public ProofSource { std::unique_ptr callback) override; QuicReferenceCountedPointer GetCertChain( const QuicSocketAddress& server_address, - const QuicSocketAddress& client_address, - const std::string& hostname) override; + const QuicSocketAddress& client_address, const std::string& hostname, + bool* cert_matched_sni) override; void ComputeTlsSignature( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, @@ -52,6 +52,8 @@ class FakeProofSource : public ProofSource { uint16_t signature_algorithm, absl::string_view in, std::unique_ptr callback) override; + absl::InlinedVector SupportedTlsSignatureAlgorithms() + const override; TicketCrypter* GetTicketCrypter() override; // Sets the TicketCrypter to use. If nullptr, the TicketCrypter from diff --git a/gquiche/quic/test_tools/fake_proof_source_handle.cc b/gquiche/quic/test_tools/fake_proof_source_handle.cc index c719d32a..6ea8d443 100644 --- a/gquiche/quic/test_tools/fake_proof_source_handle.cc +++ b/gquiche/quic/test_tools/fake_proof_source_handle.cc @@ -54,49 +54,66 @@ ComputeSignatureResult ComputeSignatureNow( } // namespace FakeProofSourceHandle::FakeProofSourceHandle( - ProofSource* delegate, - ProofSourceHandleCallback* callback, - Action select_cert_action, - Action compute_signature_action) + ProofSource* delegate, ProofSourceHandleCallback* callback, + Action select_cert_action, Action compute_signature_action, + QuicDelayedSSLConfig dealyed_ssl_config) : delegate_(delegate), callback_(callback), select_cert_action_(select_cert_action), - compute_signature_action_(compute_signature_action) {} + compute_signature_action_(compute_signature_action), + dealyed_ssl_config_(dealyed_ssl_config) {} -void FakeProofSourceHandle::CancelPendingOperation() { +void FakeProofSourceHandle::CloseHandle() { select_cert_op_.reset(); compute_signature_op_.reset(); + closed_ = true; } QuicAsyncStatus FakeProofSourceHandle::SelectCertificate( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, + absl::string_view ssl_capabilities, const std::string& hostname, absl::string_view client_hello, const std::string& alpn, + absl::optional alps, const std::vector& quic_transport_params, - const absl::optional>& early_data_context) { - all_select_cert_args_.push_back( - SelectCertArgs(server_address, client_address, hostname, client_hello, - alpn, quic_transport_params, early_data_context)); + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) { + if (select_cert_action_ != Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + QUICHE_CHECK(!closed_); + } + all_select_cert_args_.push_back(SelectCertArgs( + server_address, client_address, ssl_capabilities, hostname, client_hello, + alpn, alps, quic_transport_params, early_data_context, ssl_config)); if (select_cert_action_ == Action::DELEGATE_ASYNC || select_cert_action_ == Action::FAIL_ASYNC) { select_cert_op_.emplace(delegate_, callback_, select_cert_action_, - all_select_cert_args_.back()); + all_select_cert_args_.back(), dealyed_ssl_config_); return QUIC_PENDING; - } else if (select_cert_action_ == Action::FAIL_SYNC) { - callback()->OnSelectCertificateDone(/*ok=*/false, - /*is_sync=*/true, nullptr); + } else if (select_cert_action_ == Action::FAIL_SYNC || + select_cert_action_ == Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + callback()->OnSelectCertificateDone( + /*ok=*/false, + /*is_sync=*/true, nullptr, /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/false, dealyed_ssl_config_); return QUIC_FAILURE; } QUICHE_DCHECK(select_cert_action_ == Action::DELEGATE_SYNC); + bool cert_matched_sni; QuicReferenceCountedPointer chain = - delegate_->GetCertChain(server_address, client_address, hostname); + delegate_->GetCertChain(server_address, client_address, hostname, + &cert_matched_sni); bool ok = chain && !chain->certs.empty(); - callback_->OnSelectCertificateDone(ok, /*is_sync=*/true, chain.get()); + callback_->OnSelectCertificateDone( + ok, /*is_sync=*/true, chain.get(), + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/cert_matched_sni, dealyed_ssl_config_); return ok ? QUIC_SUCCESS : QUIC_FAILURE; } @@ -107,6 +124,9 @@ QuicAsyncStatus FakeProofSourceHandle::ComputeSignature( uint16_t signature_algorithm, absl::string_view in, size_t max_signature_size) { + if (compute_signature_action_ != Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { + QUICHE_CHECK(!closed_); + } all_compute_signature_args_.push_back( ComputeSignatureArgs(server_address, client_address, hostname, signature_algorithm, in, max_signature_size)); @@ -117,7 +137,9 @@ QuicAsyncStatus FakeProofSourceHandle::ComputeSignature( compute_signature_action_, all_compute_signature_args_.back()); return QUIC_PENDING; - } else if (compute_signature_action_ == Action::FAIL_SYNC) { + } else if (compute_signature_action_ == Action::FAIL_SYNC || + compute_signature_action_ == + Action::FAIL_SYNC_DO_NOT_CHECK_CLOSED) { callback()->OnComputeSignatureDone(/*ok=*/false, /*is_sync=*/true, /*signature=*/"", /*details=*/nullptr); return QUIC_FAILURE; @@ -159,22 +181,31 @@ int FakeProofSourceHandle::NumPendingOperations() const { } FakeProofSourceHandle::SelectCertOperation::SelectCertOperation( - ProofSource* delegate, - ProofSourceHandleCallback* callback, - Action action, - SelectCertArgs args) - : PendingOperation(delegate, callback, action), args_(std::move(args)) {} + ProofSource* delegate, ProofSourceHandleCallback* callback, Action action, + SelectCertArgs args, QuicDelayedSSLConfig dealyed_ssl_config) + : PendingOperation(delegate, callback, action), + args_(std::move(args)), + dealyed_ssl_config_(dealyed_ssl_config) {} void FakeProofSourceHandle::SelectCertOperation::Run() { if (action_ == Action::FAIL_ASYNC) { - callback_->OnSelectCertificateDone(/*ok=*/false, - /*is_sync=*/false, nullptr); + callback_->OnSelectCertificateDone( + /*ok=*/false, + /*is_sync=*/false, nullptr, + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/false, dealyed_ssl_config_); } else if (action_ == Action::DELEGATE_ASYNC) { + bool cert_matched_sni; QuicReferenceCountedPointer chain = delegate_->GetCertChain(args_.server_address, args_.client_address, - args_.hostname); + args_.hostname, &cert_matched_sni); bool ok = chain && !chain->certs.empty(); - callback_->OnSelectCertificateDone(ok, /*is_sync=*/false, chain.get()); + callback_->OnSelectCertificateDone( + ok, /*is_sync=*/false, chain.get(), + /*handshake_hints=*/absl::string_view(), + /*ticket_encryption_key=*/absl::string_view(), + /*cert_matched_sni=*/cert_matched_sni, dealyed_ssl_config_); } else { QUIC_BUG(quic_bug_10139_1) << "Unexpected action: " << static_cast(action_); diff --git a/gquiche/quic/test_tools/fake_proof_source_handle.h b/gquiche/quic/test_tools/fake_proof_source_handle.h index 58a83a96..a81c8497 100644 --- a/gquiche/quic/test_tools/fake_proof_source_handle.h +++ b/gquiche/quic/test_tools/fake_proof_source_handle.h @@ -25,26 +25,32 @@ class FakeProofSourceHandle : public ProofSourceHandle { // Handle the operation asynchronously. Fail the operation when the caller // calls CompletePendingOperation(). FAIL_ASYNC, + // Similar to FAIL_SYNC, but do not QUICHE_CHECK(!closed_) when invoked. + FAIL_SYNC_DO_NOT_CHECK_CLOSED, }; // |delegate| must do cert selection and signature synchronously. - FakeProofSourceHandle(ProofSource* delegate, - ProofSourceHandleCallback* callback, - Action select_cert_action, - Action compute_signature_action); + // |dealyed_ssl_config| is the config passed to OnSelectCertificateDone. + FakeProofSourceHandle( + ProofSource* delegate, ProofSourceHandleCallback* callback, + Action select_cert_action, Action compute_signature_action, + QuicDelayedSSLConfig dealyed_ssl_config = QuicDelayedSSLConfig()); ~FakeProofSourceHandle() override = default; - void CancelPendingOperation() override; + void CloseHandle() override; QuicAsyncStatus SelectCertificate( const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, + absl::string_view ssl_capabilities, const std::string& hostname, absl::string_view client_hello, const std::string& alpn, + absl::optional alps, const std::vector& quic_transport_params, - const absl::optional>& early_data_context) override; + const absl::optional>& early_data_context, + const QuicSSLConfig& ssl_config) override; QuicAsyncStatus ComputeSignature(const QuicSocketAddress& server_address, const QuicSocketAddress& client_address, @@ -62,26 +68,35 @@ class FakeProofSourceHandle : public ProofSourceHandle { struct SelectCertArgs { SelectCertArgs(QuicSocketAddress server_address, QuicSocketAddress client_address, + absl::string_view ssl_capabilities, std::string hostname, absl::string_view client_hello, std::string alpn, + absl::optional alps, std::vector quic_transport_params, - absl::optional> early_data_context) + absl::optional> early_data_context, + QuicSSLConfig ssl_config) : server_address(server_address), client_address(client_address), + ssl_capabilities(ssl_capabilities), hostname(hostname), client_hello(client_hello), alpn(alpn), + alps(alps), quic_transport_params(quic_transport_params), - early_data_context(early_data_context) {} + early_data_context(early_data_context), + ssl_config(ssl_config) {} QuicSocketAddress server_address; QuicSocketAddress client_address; + std::string ssl_capabilities; std::string hostname; std::string client_hello; std::string alpn; + absl::optional alps; std::vector quic_transport_params; absl::optional> early_data_context; + QuicSSLConfig ssl_config; }; struct ComputeSignatureArgs { @@ -133,9 +148,9 @@ class FakeProofSourceHandle : public ProofSourceHandle { class SelectCertOperation : public PendingOperation { public: SelectCertOperation(ProofSource* delegate, - ProofSourceHandleCallback* callback, - Action action, - SelectCertArgs args); + ProofSourceHandleCallback* callback, Action action, + SelectCertArgs args, + QuicDelayedSSLConfig dealyed_ssl_config); ~SelectCertOperation() override = default; @@ -143,6 +158,7 @@ class FakeProofSourceHandle : public ProofSourceHandle { private: const SelectCertArgs args_; + const QuicDelayedSSLConfig dealyed_ssl_config_; }; class ComputeSignatureOperation : public PendingOperation { @@ -163,12 +179,14 @@ class FakeProofSourceHandle : public ProofSourceHandle { private: int NumPendingOperations() const; + bool closed_ = false; ProofSource* delegate_; ProofSourceHandleCallback* callback_; // Action for the next select cert operation. Action select_cert_action_ = Action::DELEGATE_SYNC; // Action for the next compute signature operation. Action compute_signature_action_ = Action::DELEGATE_SYNC; + const QuicDelayedSSLConfig dealyed_ssl_config_; absl::optional select_cert_op_; absl::optional compute_signature_op_; diff --git a/gquiche/quic/test_tools/first_flight.cc b/gquiche/quic/test_tools/first_flight.cc index dd9026e1..d971ef03 100644 --- a/gquiche/quic/test_tools/first_flight.cc +++ b/gquiche/quic/test_tools/first_flight.cc @@ -32,18 +32,28 @@ class FirstFlightExtractor : public DelegatedPacketWriter::Delegate { FirstFlightExtractor(const ParsedQuicVersion& version, const QuicConfig& config, const QuicConnectionId& server_connection_id, - const QuicConnectionId& client_connection_id) + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config) : version_(version), server_connection_id_(server_connection_id), client_connection_id_(client_connection_id), writer_(this), config_(config), - crypto_config_(crypto_test_utils::ProofVerifierForTesting()) { + crypto_config_(std::move(crypto_config)) { EXPECT_NE(version_, UnsupportedQuicVersion()); } + FirstFlightExtractor(const ParsedQuicVersion& version, + const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id) + : FirstFlightExtractor( + version, config, server_connection_id, client_connection_id, + std::make_unique( + crypto_test_utils::ProofVerifierForTesting())) {} + void GenerateFirstFlight() { - crypto_config_.set_alpn(AlpnForVersion(version_)); + crypto_config_->set_alpn(AlpnForVersion(version_)); connection_ = new QuicConnection(server_connection_id_, /*initial_self_address=*/QuicSocketAddress(), @@ -55,7 +65,7 @@ class FirstFlightExtractor : public DelegatedPacketWriter::Delegate { session_ = std::make_unique( config_, ParsedQuicVersionVector{version_}, connection_, // session_ takes ownership of connection_ here. - TestServerId(), &crypto_config_, &push_promise_index_); + TestServerId(), crypto_config_.get(), &push_promise_index_); session_->Initialize(); session_->CryptoConnect(); } @@ -84,13 +94,25 @@ class FirstFlightExtractor : public DelegatedPacketWriter::Delegate { MockAlarmFactory alarm_factory_; DelegatedPacketWriter writer_; QuicConfig config_; - QuicCryptoClientConfig crypto_config_; + std::unique_ptr crypto_config_; QuicClientPushPromiseIndex push_promise_index_; QuicConnection* connection_; // Owned by session_. std::unique_ptr session_; std::vector> packets_; }; +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config) { + FirstFlightExtractor first_flight_extractor( + version, config, server_connection_id, client_connection_id, + std::move(crypto_config)); + first_flight_extractor.GenerateFirstFlight(); + return first_flight_extractor.ConsumePackets(); +} + std::vector> GetFirstFlightOfPackets( const ParsedQuicVersion& version, const QuicConfig& config, diff --git a/gquiche/quic/test_tools/first_flight.h b/gquiche/quic/test_tools/first_flight.h index 374fb443..35e429df 100644 --- a/gquiche/quic/test_tools/first_flight.h +++ b/gquiche/quic/test_tools/first_flight.h @@ -8,6 +8,7 @@ #include #include +#include "gquiche/quic/core/crypto/quic_crypto_client_config.h" #include "gquiche/quic/core/quic_config.h" #include "gquiche/quic/core/quic_connection_id.h" #include "gquiche/quic/core/quic_packet_writer.h" @@ -74,16 +75,23 @@ class QUIC_NO_EXPORT DelegatedPacketWriter : public QuicPacketWriter { // HTTP/3 connection. In most cases, this array will only contain one packet // that carries the CHLO. std::vector> GetFirstFlightOfPackets( - const ParsedQuicVersion& version, - const QuicConfig& config, + const ParsedQuicVersion& version, const QuicConfig& config, const QuicConnectionId& server_connection_id, - const QuicConnectionId& client_connection_id); + const QuicConnectionId& client_connection_id, + std::unique_ptr crypto_config); // Below are various convenience overloads that use default values for the // omitted parameters: // |config| = DefaultQuicConfig(), // |server_connection_id| = TestConnectionId(), // |client_connection_id| = EmptyQuicConnectionId(). +// |crypto_config| = +// QuicCryptoClientConfig(crypto_test_utils::ProofVerifierForTesting()) +std::vector> GetFirstFlightOfPackets( + const ParsedQuicVersion& version, const QuicConfig& config, + const QuicConnectionId& server_connection_id, + const QuicConnectionId& client_connection_id); + std::vector> GetFirstFlightOfPackets( const ParsedQuicVersion& version, const QuicConfig& config, diff --git a/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.cc b/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.cc index 3e4732e4..7cdcc803 100644 --- a/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.cc +++ b/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.cc @@ -18,9 +18,9 @@ MockTimeWaitListManager::MockTimeWaitListManager( : QuicTimeWaitListManager(writer, visitor, clock, alarm_factory) { // Though AddConnectionIdToTimeWait is mocked, we want to retain its // functionality. - EXPECT_CALL(*this, AddConnectionIdToTimeWait(_, _, _)) + EXPECT_CALL(*this, AddConnectionIdToTimeWait(_, _)) .Times(testing::AnyNumber()); - ON_CALL(*this, AddConnectionIdToTimeWait(_, _, _)) + ON_CALL(*this, AddConnectionIdToTimeWait(_, _)) .WillByDefault( Invoke(this, &MockTimeWaitListManager:: QuicTimeWaitListManager_AddConnectionIdToTimeWait)); diff --git a/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.h b/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.h index eba8fd51..9e820489 100644 --- a/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.h +++ b/gquiche/quic/test_tools/mock_quic_time_wait_list_manager.h @@ -19,19 +19,15 @@ class MockTimeWaitListManager : public QuicTimeWaitListManager { QuicAlarmFactory* alarm_factory); ~MockTimeWaitListManager() override; - MOCK_METHOD(void, - AddConnectionIdToTimeWait, - (QuicConnectionId connection_id, - QuicTimeWaitListManager::TimeWaitAction action, + MOCK_METHOD(void, AddConnectionIdToTimeWait, + (QuicTimeWaitListManager::TimeWaitAction action, quic::TimeWaitConnectionInfo info), (override)); void QuicTimeWaitListManager_AddConnectionIdToTimeWait( - QuicConnectionId connection_id, QuicTimeWaitListManager::TimeWaitAction action, quic::TimeWaitConnectionInfo info) { - QuicTimeWaitListManager::AddConnectionIdToTimeWait(connection_id, action, - std::move(info)); + QuicTimeWaitListManager::AddConnectionIdToTimeWait(action, std::move(info)); } MOCK_METHOD(void, diff --git a/gquiche/quic/test_tools/mock_random.cc b/gquiche/quic/test_tools/mock_random.cc index 0d68219a..d305dd22 100644 --- a/gquiche/quic/test_tools/mock_random.cc +++ b/gquiche/quic/test_tools/mock_random.cc @@ -33,5 +33,10 @@ void MockRandom::ChangeValue() { increment_++; } +void MockRandom::ResetBase(uint32_t base) { + base_ = base; + increment_ = 0; +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/mock_random.h b/gquiche/quic/test_tools/mock_random.h index 6c9a3b2e..5da7967e 100644 --- a/gquiche/quic/test_tools/mock_random.h +++ b/gquiche/quic/test_tools/mock_random.h @@ -33,6 +33,9 @@ class MockRandom : public QuicRandom { // |RandUint64| and the byte that |RandBytes| fills with, to change. void ChangeValue(); + // Sets the base to |base| and resets increment to zero. + void ResetBase(uint32_t base); + private: uint32_t base_; uint8_t increment_; diff --git a/gquiche/quic/test_tools/packet_dropping_test_writer.cc b/gquiche/quic/test_tools/packet_dropping_test_writer.cc index de52fd2f..c2aadcc9 100644 --- a/gquiche/quic/test_tools/packet_dropping_test_writer.cc +++ b/gquiche/quic/test_tools/packet_dropping_test_writer.cc @@ -18,7 +18,7 @@ const int32_t kMinSuccesfulWritesAfterPacketLoss = 2; // An alarm that is scheduled if a blocked socket is simulated to indicate // it's writable again. -class WriteUnblockedAlarm : public QuicAlarm::Delegate { +class WriteUnblockedAlarm : public QuicAlarm::DelegateWithoutContext { public: explicit WriteUnblockedAlarm(PacketDroppingTestWriter* writer) : writer_(writer) {} @@ -34,7 +34,7 @@ class WriteUnblockedAlarm : public QuicAlarm::Delegate { // An alarm that is scheduled every time a new packet is to be written at a // later point. -class DelayAlarm : public QuicAlarm::Delegate { +class DelayAlarm : public QuicAlarm::DelegateWithoutContext { public: explicit DelayAlarm(PacketDroppingTestWriter* writer) : writer_(writer) {} @@ -68,7 +68,14 @@ PacketDroppingTestWriter::PacketDroppingTestWriter() simple_random_.set_seed(seed); } -PacketDroppingTestWriter::~PacketDroppingTestWriter() = default; +PacketDroppingTestWriter::~PacketDroppingTestWriter() { + if (write_unblocked_alarm_ != nullptr) { + write_unblocked_alarm_->PermanentCancel(); + } + if (delay_alarm_ != nullptr) { + delay_alarm_->PermanentCancel(); + } +} void PacketDroppingTestWriter::Initialize( QuicConnectionHelperInterface* helper, diff --git a/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc b/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc index b52448a7..372b58d8 100644 --- a/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc +++ b/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.cc @@ -37,7 +37,7 @@ void TestHeadersHandler::OnDecodingCompleted() { } void TestHeadersHandler::OnDecodingErrorDetected( - absl::string_view error_message) { + QuicErrorCode /*error_code*/, absl::string_view error_message) { ASSERT_FALSE(decoding_completed_); ASSERT_FALSE(decoding_error_detected_); diff --git a/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h b/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h index 4d0f80b2..565c2b1d 100644 --- a/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h +++ b/gquiche/quic/test_tools/qpack/qpack_decoder_test_utils.h @@ -10,6 +10,7 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/qpack/qpack_decoder.h" #include "gquiche/quic/core/qpack/qpack_progressive_decoder.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" #include "gquiche/spdy/core/spdy_header_block.h" @@ -51,7 +52,8 @@ class TestHeadersHandler void OnHeaderDecoded(absl::string_view name, absl::string_view value) override; void OnDecodingCompleted() override; - void OnDecodingErrorDetected(absl::string_view error_message) override; + void OnDecodingErrorDetected(QuicErrorCode error_code, + absl::string_view error_message) override; // Release decoded header list. Must only be called if decoding is complete // and no errors have been detected. @@ -81,9 +83,8 @@ class MockHeadersHandler (absl::string_view name, absl::string_view value), (override)); MOCK_METHOD(void, OnDecodingCompleted, (), (override)); - MOCK_METHOD(void, - OnDecodingErrorDetected, - (absl::string_view error_message), + MOCK_METHOD(void, OnDecodingErrorDetected, + (QuicErrorCode error_code, absl::string_view error_message), (override)); }; @@ -95,7 +96,8 @@ class NoOpHeadersHandler void OnHeaderDecoded(absl::string_view /*name*/, absl::string_view /*value*/) override {} void OnDecodingCompleted() override {} - void OnDecodingErrorDetected(absl::string_view /*error_message*/) override {} + void OnDecodingErrorDetected(QuicErrorCode /*error_code*/, + absl::string_view /*error_message*/) override {} }; void QpackDecode( diff --git a/gquiche/quic/test_tools/qpack/qpack_offline_decoder.cc b/gquiche/quic/test_tools/qpack/qpack_offline_decoder.cc index 3528b9c2..07f80a13 100644 --- a/gquiche/quic/test_tools/qpack/qpack_offline_decoder.cc +++ b/gquiche/quic/test_tools/qpack/qpack_offline_decoder.cc @@ -35,10 +35,9 @@ #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/quic_types.h" -#include "gquiche/quic/platform/api/quic_file_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" #include "gquiche/quic/test_tools/qpack/qpack_test_utils.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/platform/api/quiche_file_utils.h" #include "gquiche/common/quiche_endian.h" namespace quic { @@ -70,8 +69,7 @@ bool QpackOfflineDecoder::DecodeAndVerifyOfflineData( } void QpackOfflineDecoder::OnEncoderStreamError( - QuicErrorCode error_code, - absl::string_view error_message) { + QuicErrorCode error_code, absl::string_view error_message) { QUIC_LOG(ERROR) << "Encoder stream error: " << QuicErrorCodeToString(error_code) << " " << error_message; encoder_stream_error_detected_ = true; @@ -88,12 +86,7 @@ bool QpackOfflineDecoder::ParseInputFilename(absl::string_view input_filename) { auto piece_it = pieces.rbegin(); // Acknowledgement mode: 1 for immediate, 0 for none. - bool immediate_acknowledgement = false; - if (*piece_it == "0") { - immediate_acknowledgement = false; - } else if (*piece_it == "1") { - immediate_acknowledgement = true; - } else { + if (*piece_it != "0" && *piece_it != "1") { QUIC_LOG(ERROR) << "Header acknowledgement field must be 0 or 1 in input filename " << input_filename; @@ -137,9 +130,10 @@ bool QpackOfflineDecoder::DecodeHeaderBlocksFromFile( absl::string_view input_filename) { // Store data in |input_data_storage|; use a absl::string_view to // efficiently keep track of remaining portion yet to be decoded. - std::string input_data_storage; - ReadFileContents(input_filename, &input_data_storage); - absl::string_view input_data(input_data_storage); + absl::optional input_data_storage = + quiche::ReadFileContents(input_filename); + QUICHE_DCHECK(input_data_storage.has_value()); + absl::string_view input_data(*input_data_storage); while (!input_data.empty()) { // Parse stream_id and length. @@ -234,9 +228,10 @@ bool QpackOfflineDecoder::VerifyDecodedHeaderLists( // Store data in |expected_headers_data_storage|; use a // absl::string_view to efficiently keep track of remaining portion // yet to be decoded. - std::string expected_headers_data_storage; - ReadFileContents(expected_headers_filename, &expected_headers_data_storage); - absl::string_view expected_headers_data(expected_headers_data_storage); + absl::optional expected_headers_data_storage = + quiche::ReadFileContents(expected_headers_filename); + QUICHE_DCHECK(expected_headers_data_storage.has_value()); + absl::string_view expected_headers_data(*expected_headers_data_storage); while (!decoded_header_lists_.empty()) { spdy::Http2HeaderBlock decoded_header_list = diff --git a/gquiche/quic/test_tools/quic_client_session_cache_peer.h b/gquiche/quic/test_tools/quic_client_session_cache_peer.h new file mode 100644 index 00000000..bfdd7a0f --- /dev/null +++ b/gquiche/quic/test_tools/quic_client_session_cache_peer.h @@ -0,0 +1,28 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_SESSION_CACHE_PEER_H_ +#define QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_SESSION_CACHE_PEER_H_ + +#include "gquiche/quic/core/crypto/quic_client_session_cache.h" + +namespace quic { +namespace test { + +class QuicClientSessionCachePeer { + public: + static std::string GetToken(QuicClientSessionCache* cache, + const QuicServerId& server_id) { + auto iter = cache->cache_.Lookup(server_id); + if (iter == cache->cache_.end()) { + return {}; + } + return iter->second->token; + } +}; + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_QUIC_CLIENT_SESSION_CACHE_PEER_H_ diff --git a/gquiche/quic/test_tools/quic_connection_peer.cc b/gquiche/quic/test_tools/quic_connection_peer.cc index 913ba2a2..303c1727 100644 --- a/gquiche/quic/test_tools/quic_connection_peer.cc +++ b/gquiche/quic/test_tools/quic_connection_peer.cc @@ -398,11 +398,7 @@ QuicIdleNetworkDetector& QuicConnectionPeer::GetIdleNetworkDetector( void QuicConnectionPeer::SetServerConnectionId( QuicConnection* connection, const QuicConnectionId& server_connection_id) { - if (connection->use_connection_id_on_default_path_) { - connection->default_path_.server_connection_id = server_connection_id; - } else { - connection->server_connection_id_ = server_connection_id; - } + connection->default_path_.server_connection_id = server_connection_id; connection->InstallInitialCrypters(server_connection_id); } @@ -430,7 +426,7 @@ void QuicConnectionPeer::SendPing(QuicConnection* connection) { void QuicConnectionPeer::SetLastPacketDestinationAddress( QuicConnection* connection, const QuicSocketAddress& address) { - connection->last_packet_destination_address_ = address; + connection->last_received_packet_info_.destination_address = address; } // static @@ -483,18 +479,6 @@ QuicByteCount QuicConnectionPeer::BytesReceivedBeforeAddressValidation( return connection->default_path_.bytes_received_before_address_validation; } -// static -void QuicConnectionPeer::EnableMultipleConnectionIdSupport( - QuicConnection* connection) { - connection->support_multiple_connection_ids_ = true; -} - -// static -void QuicConnectionPeer::EnableConnectionMigrationUseNewCID( - QuicConnection* connection) { - connection->connection_migration_use_new_cid_ = true; -} - // static void QuicConnectionPeer::ResetPeerIssuedConnectionIdManager( QuicConnection* connection) { @@ -519,5 +503,37 @@ void QuicConnectionPeer::RetirePeerIssuedConnectionIdsNoLongerOnPath( connection->RetirePeerIssuedConnectionIdsNoLongerOnPath(); } +// static +bool QuicConnectionPeer::HasUnusedPeerIssuedConnectionId( + const QuicConnection* connection) { + return connection->peer_issued_cid_manager_->HasUnusedConnectionId(); +} + +// static +bool QuicConnectionPeer::HasSelfIssuedConnectionIdToConsume( + const QuicConnection* connection) { + return connection->self_issued_cid_manager_->HasConnectionIdToConsume(); +} + +// static +QuicSelfIssuedConnectionIdManager* +QuicConnectionPeer::GetSelfIssuedConnectionIdManager( + QuicConnection* connection) { + return connection->self_issued_cid_manager_.get(); +} + +// static +std::unique_ptr +QuicConnectionPeer::MakeSelfIssuedConnectionIdManager( + QuicConnection* connection) { + return connection->MakeSelfIssuedConnectionIdManager(); +} + +// static +void QuicConnectionPeer::SetLastDecryptedLevel(QuicConnection* connection, + EncryptionLevel level) { + connection->last_decrypted_packet_level_ = level; +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/quic_connection_peer.h b/gquiche/quic/test_tools/quic_connection_peer.h index b189c413..5ca0f00b 100644 --- a/gquiche/quic/test_tools/quic_connection_peer.h +++ b/gquiche/quic/test_tools/quic_connection_peer.h @@ -197,11 +197,6 @@ class QuicConnectionPeer { static QuicByteCount BytesReceivedBeforeAddressValidation( QuicConnection* connection); - static void EnableMultipleConnectionIdSupport(QuicConnection* connection); - - // Remove this method once the boolean is enabled via reloadable flag. - static void EnableConnectionMigrationUseNewCID(QuicConnection* connection); - static void ResetPeerIssuedConnectionIdManager(QuicConnection* connection); static QuicConnection::PathState* GetDefaultPath(QuicConnection* connection); @@ -211,6 +206,20 @@ class QuicConnectionPeer { static void RetirePeerIssuedConnectionIdsNoLongerOnPath( QuicConnection* connection); + + static bool HasUnusedPeerIssuedConnectionId(const QuicConnection* connection); + + static bool HasSelfIssuedConnectionIdToConsume( + const QuicConnection* connection); + + static QuicSelfIssuedConnectionIdManager* GetSelfIssuedConnectionIdManager( + QuicConnection* connection); + + static std::unique_ptr + MakeSelfIssuedConnectionIdManager(QuicConnection* connection); + + static void SetLastDecryptedLevel(QuicConnection* connection, + EncryptionLevel level); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_crypto_server_config_peer.cc b/gquiche/quic/test_tools/quic_crypto_server_config_peer.cc index 2835e379..503e9a27 100644 --- a/gquiche/quic/test_tools/quic_crypto_server_config_peer.cc +++ b/gquiche/quic/test_tools/quic_crypto_server_config_peer.cc @@ -59,7 +59,7 @@ HandshakeFailureReason QuicCryptoServerConfigPeer::ValidateSourceAddressTokens( CachedNetworkParameters* cached_network_params) { SourceAddressTokens tokens; HandshakeFailureReason reason = server_config_->ParseSourceAddressToken( - *GetConfig(config_id)->source_address_token_boxer, srct, &tokens); + *GetConfig(config_id)->source_address_token_boxer, srct, tokens); if (reason != HANDSHAKE_OK) { return reason; } @@ -75,7 +75,7 @@ QuicCryptoServerConfigPeer::ValidateSingleSourceAddressToken( QuicWallTime now) { SourceAddressTokens tokens; HandshakeFailureReason parse_status = server_config_->ParseSourceAddressToken( - *GetPrimaryConfig()->source_address_token_boxer, token, &tokens); + *GetPrimaryConfig()->source_address_token_boxer, token, tokens); if (HANDSHAKE_OK != parse_status) { return parse_status; } diff --git a/gquiche/quic/test_tools/quic_crypto_server_config_peer.h b/gquiche/quic/test_tools/quic_crypto_server_config_peer.h index 7afde027..b3eea1ff 100644 --- a/gquiche/quic/test_tools/quic_crypto_server_config_peer.h +++ b/gquiche/quic/test_tools/quic_crypto_server_config_peer.h @@ -40,13 +40,10 @@ class QuicCryptoServerConfigPeer { QuicWallTime now, CachedNetworkParameters* cached_network_params); - // Attempts to validate the tokens in |tokens|. + // Attempts to validate the tokens in |srct|. HandshakeFailureReason ValidateSourceAddressTokens( - std::string config_id, - absl::string_view tokens, - const QuicIpAddress& ip, - QuicWallTime now, - CachedNetworkParameters* cached_network_params); + std::string config_id, absl::string_view srct, const QuicIpAddress& ip, + QuicWallTime now, CachedNetworkParameters* cached_network_params); // Attempts to validate the single source address token in |token|. HandshakeFailureReason ValidateSingleSourceAddressToken( diff --git a/gquiche/quic/test_tools/quic_dispatcher_peer.cc b/gquiche/quic/test_tools/quic_dispatcher_peer.cc index 00a914ae..39e3a855 100644 --- a/gquiche/quic/test_tools/quic_dispatcher_peer.cc +++ b/gquiche/quic/test_tools/quic_dispatcher_peer.cc @@ -117,31 +117,26 @@ std::string QuicDispatcherPeer::SelectAlpn( // static QuicSession* QuicDispatcherPeer::GetFirstSessionIfAny( QuicDispatcher* dispatcher) { - if (dispatcher->use_reference_counted_session_map()) { - if (dispatcher->reference_counted_session_map_.empty()) { - return nullptr; - } - return dispatcher->reference_counted_session_map_.begin()->second.get(); - } else { - if (dispatcher->session_map_.empty()) { - return nullptr; - } - return dispatcher->session_map_.begin()->second.get(); + if (dispatcher->reference_counted_session_map_.empty()) { + return nullptr; } + return dispatcher->reference_counted_session_map_.begin()->second.get(); } // static const QuicSession* QuicDispatcherPeer::FindSession( const QuicDispatcher* dispatcher, QuicConnectionId id) { - if (dispatcher->use_reference_counted_session_map()) { - auto it = dispatcher->reference_counted_session_map_.find(id); - return (it == dispatcher->reference_counted_session_map_.end()) - ? nullptr - : it->second.get(); - } - auto it = dispatcher->session_map_.find(id); - return (it == dispatcher->session_map_.end()) ? nullptr : it->second.get(); + auto it = dispatcher->reference_counted_session_map_.find(id); + return (it == dispatcher->reference_counted_session_map_.end()) + ? nullptr + : it->second.get(); +} + +// static +QuicAlarm* QuicDispatcherPeer::GetClearResetAddressesAlarm( + QuicDispatcher* dispatcher) { + return dispatcher->clear_stateless_reset_addresses_alarm_.get(); } } // namespace test diff --git a/gquiche/quic/test_tools/quic_dispatcher_peer.h b/gquiche/quic/test_tools/quic_dispatcher_peer.h index 205859ce..af241229 100644 --- a/gquiche/quic/test_tools/quic_dispatcher_peer.h +++ b/gquiche/quic/test_tools/quic_dispatcher_peer.h @@ -76,6 +76,8 @@ class QuicDispatcherPeer { // Find the corresponding session if exsits. static const QuicSession* FindSession(const QuicDispatcher* dispatcher, QuicConnectionId id); + + static QuicAlarm* GetClearResetAddressesAlarm(QuicDispatcher* dispatcher); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_framer_peer.cc b/gquiche/quic/test_tools/quic_framer_peer.cc index 3201c1bc..353774e5 100644 --- a/gquiche/quic/test_tools/quic_framer_peer.cc +++ b/gquiche/quic/test_tools/quic_framer_peer.cc @@ -6,7 +6,6 @@ #include "gquiche/quic/core/quic_framer.h" #include "gquiche/quic/core/quic_packets.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { namespace test { diff --git a/gquiche/quic/test_tools/quic_packet_creator_peer.cc b/gquiche/quic/test_tools/quic_packet_creator_peer.cc index 9644ccb6..eb801c47 100644 --- a/gquiche/quic/test_tools/quic_packet_creator_peer.cc +++ b/gquiche/quic/test_tools/quic_packet_creator_peer.cc @@ -156,5 +156,11 @@ QuicFrames& QuicPacketCreatorPeer::QueuedFrames(QuicPacketCreator* creator) { return creator->queued_frames_; } +// static +void QuicPacketCreatorPeer::SetRandom(QuicPacketCreator* creator, + QuicRandom* random) { + creator->random_ = random; +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/quic_packet_creator_peer.h b/gquiche/quic/test_tools/quic_packet_creator_peer.h index 239dc604..e2ce91ba 100644 --- a/gquiche/quic/test_tools/quic_packet_creator_peer.h +++ b/gquiche/quic/test_tools/quic_packet_creator_peer.h @@ -10,6 +10,7 @@ namespace quic { class QuicFramer; class QuicPacketCreator; +class QuicRandom; namespace test { @@ -61,6 +62,7 @@ class QuicPacketCreatorPeer { static QuicFramer* framer(QuicPacketCreator* creator); static std::string GetRetryToken(QuicPacketCreator* creator); static QuicFrames& QueuedFrames(QuicPacketCreator* creator); + static void SetRandom(QuicPacketCreator* creator, QuicRandom* random); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_sent_packet_manager_peer.cc b/gquiche/quic/test_tools/quic_sent_packet_manager_peer.cc index 5de1952b..ede08e08 100644 --- a/gquiche/quic/test_tools/quic_sent_packet_manager_peer.cc +++ b/gquiche/quic/test_tools/quic_sent_packet_manager_peer.cc @@ -214,5 +214,11 @@ bool QuicSentPacketManagerPeer::UsePacketThresholdForRuntPackets( .use_packet_threshold_for_runt_packets(); } +// static +int QuicSentPacketManagerPeer::GetNumPtosForPathDegrading( + QuicSentPacketManager* sent_packet_manager) { + return sent_packet_manager->num_ptos_for_path_degrading_; +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/quic_sent_packet_manager_peer.h b/gquiche/quic/test_tools/quic_sent_packet_manager_peer.h index ea6092ee..02309d85 100644 --- a/gquiche/quic/test_tools/quic_sent_packet_manager_peer.h +++ b/gquiche/quic/test_tools/quic_sent_packet_manager_peer.h @@ -98,6 +98,9 @@ class QuicSentPacketManagerPeer { static bool UsePacketThresholdForRuntPackets( QuicSentPacketManager* sent_packet_manager); + + static int GetNumPtosForPathDegrading( + QuicSentPacketManager* sent_packet_manager); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_session_peer.cc b/gquiche/quic/test_tools/quic_session_peer.cc index 9c415425..5f900937 100644 --- a/gquiche/quic/test_tools/quic_session_peer.cc +++ b/gquiche/quic/test_tools/quic_session_peer.cc @@ -8,7 +8,6 @@ #include "gquiche/quic/core/quic_session.h" #include "gquiche/quic/core/quic_stream.h" #include "gquiche/quic/core/quic_utils.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { namespace test { @@ -152,24 +151,20 @@ bool QuicSessionPeer::IsStreamClosed(QuicSession* session, QuicStreamId id) { // static bool QuicSessionPeer::IsStreamCreated(QuicSession* session, QuicStreamId id) { - return QuicContainsKey(session->stream_map_, id); + return session->stream_map_.contains(id); } // static bool QuicSessionPeer::IsStreamAvailable(QuicSession* session, QuicStreamId id) { if (VersionHasIetfQuicFrames(session->transport_version())) { if (id % QuicUtils::StreamIdDelta(session->transport_version()) < 2) { - return QuicContainsKey( - session->ietf_streamid_manager_.bidirectional_stream_id_manager_ - .available_streams_, - id); + return session->ietf_streamid_manager_.bidirectional_stream_id_manager_ + .available_streams_.contains(id); } - return QuicContainsKey( - session->ietf_streamid_manager_.unidirectional_stream_id_manager_ - .available_streams_, - id); + return session->ietf_streamid_manager_.unidirectional_stream_id_manager_ + .available_streams_.contains(id); } - return QuicContainsKey(session->stream_id_manager_.available_streams_, id); + return session->stream_id_manager_.available_streams_.contains(id); } // static diff --git a/gquiche/quic/test_tools/quic_spdy_session_peer.cc b/gquiche/quic/test_tools/quic_spdy_session_peer.cc index db25707d..048ece6d 100644 --- a/gquiche/quic/test_tools/quic_spdy_session_peer.cc +++ b/gquiche/quic/test_tools/quic_spdy_session_peer.cc @@ -40,18 +40,14 @@ spdy::SpdyFramer* QuicSpdySessionPeer::GetSpdyFramer(QuicSpdySession* session) { } void QuicSpdySessionPeer::SetMaxInboundHeaderListSize( - QuicSpdySession* session, - size_t max_inbound_header_size) { + QuicSpdySession* session, size_t max_inbound_header_size) { session->set_max_inbound_header_list_size(max_inbound_header_size); } // static size_t QuicSpdySessionPeer::WriteHeadersOnHeadersStream( - QuicSpdySession* session, - QuicStreamId id, - spdy::SpdyHeaderBlock headers, - bool fin, - const spdy::SpdyStreamPrecedence& precedence, + QuicSpdySession* session, QuicStreamId id, spdy::SpdyHeaderBlock headers, + bool fin, const spdy::SpdyStreamPrecedence& precedence, QuicReferenceCountedPointer ack_listener) { return session->WriteHeadersOnHeadersStream( id, std::move(headers), fin, precedence, std::move(ack_listener)); @@ -100,17 +96,22 @@ QpackReceiveStream* QuicSpdySessionPeer::GetQpackEncoderReceiveStream( } // static -void QuicSpdySessionPeer::SetH3DatagramSupported(QuicSpdySession* session, - bool h3_datagram_supported) { - session->h3_datagram_supported_ = h3_datagram_supported; +void QuicSpdySessionPeer::SetHttpDatagramSupport( + QuicSpdySession* session, HttpDatagramSupport http_datagram_support) { + session->http_datagram_support_ = http_datagram_support; } // static -void QuicSpdySessionPeer::EnableWebTransport(QuicSpdySession& session) { - SetQuicReloadableFlag(quic_h3_datagram, true); - QUICHE_DCHECK(session.WillNegotiateWebTransport()); - session.h3_datagram_supported_ = true; - session.peer_supports_webtransport_ = true; +HttpDatagramSupport QuicSpdySessionPeer::LocalHttpDatagramSupport( + QuicSpdySession* session) { + return session->LocalHttpDatagramSupport(); +} + +// static +void QuicSpdySessionPeer::EnableWebTransport(QuicSpdySession* session) { + QUICHE_DCHECK(session->WillNegotiateWebTransport()); + SetHttpDatagramSupport(session, HttpDatagramSupport::kDraft04); + session->peer_supports_webtransport_ = true; } } // namespace test diff --git a/gquiche/quic/test_tools/quic_spdy_session_peer.h b/gquiche/quic/test_tools/quic_spdy_session_peer.h index 486572ed..747bc379 100644 --- a/gquiche/quic/test_tools/quic_spdy_session_peer.h +++ b/gquiche/quic/test_tools/quic_spdy_session_peer.h @@ -7,6 +7,7 @@ #include "gquiche/quic/core/http/quic_receive_control_stream.h" #include "gquiche/quic/core/http/quic_send_control_stream.h" +#include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/qpack/qpack_receive_stream.h" #include "gquiche/quic/core/qpack/qpack_send_stream.h" #include "gquiche/quic/core/quic_packets.h" @@ -16,7 +17,6 @@ namespace quic { class QuicHeadersStream; -class QuicSpdySession; namespace test { @@ -32,11 +32,8 @@ class QuicSpdySessionPeer { static void SetMaxInboundHeaderListSize(QuicSpdySession* session, size_t max_inbound_header_size); static size_t WriteHeadersOnHeadersStream( - QuicSpdySession* session, - QuicStreamId id, - spdy::SpdyHeaderBlock headers, - bool fin, - const spdy::SpdyStreamPrecedence& precedence, + QuicSpdySession* session, QuicStreamId id, spdy::SpdyHeaderBlock headers, + bool fin, const spdy::SpdyStreamPrecedence& precedence, QuicReferenceCountedPointer ack_listener); // |session| can't be nullptr. static QuicStreamId GetNextOutgoingUnidirectionalStreamId( @@ -50,9 +47,10 @@ class QuicSpdySessionPeer { QuicSpdySession* session); static QpackReceiveStream* GetQpackEncoderReceiveStream( QuicSpdySession* session); - static void SetH3DatagramSupported(QuicSpdySession* session, - bool h3_datagram_supported); - static void EnableWebTransport(QuicSpdySession& session); + static void SetHttpDatagramSupport(QuicSpdySession* session, + HttpDatagramSupport http_datagram_support); + static HttpDatagramSupport LocalHttpDatagramSupport(QuicSpdySession* session); + static void EnableWebTransport(QuicSpdySession* session); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_spdy_stream_peer.cc b/gquiche/quic/test_tools/quic_spdy_stream_peer.cc index 1314148b..3fa0bf4f 100644 --- a/gquiche/quic/test_tools/quic_spdy_stream_peer.cc +++ b/gquiche/quic/test_tools/quic_spdy_stream_peer.cc @@ -23,5 +23,10 @@ QuicSpdyStreamPeer::unacked_frame_headers_offsets(QuicSpdyStream* stream) { return stream->unacked_frame_headers_offsets_; } +// static +bool QuicSpdyStreamPeer::use_datagram_contexts(QuicSpdyStream* stream) { + return stream->use_datagram_contexts_; +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/quic_spdy_stream_peer.h b/gquiche/quic/test_tools/quic_spdy_stream_peer.h index 00964524..aa4d1319 100644 --- a/gquiche/quic/test_tools/quic_spdy_stream_peer.h +++ b/gquiche/quic/test_tools/quic_spdy_stream_peer.h @@ -23,6 +23,7 @@ class QuicSpdyStreamPeer { QuicReferenceCountedPointer ack_listener); static const QuicIntervalSet& unacked_frame_headers_offsets( QuicSpdyStream* stream); + static bool use_datagram_contexts(QuicSpdyStream* stream); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc b/gquiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc index df40bcd2..5a082391 100644 --- a/gquiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc +++ b/gquiche/quic/test_tools/quic_stream_sequencer_buffer_peer.cc @@ -46,8 +46,7 @@ bool QuicStreamSequencerBufferPeer::IsBlockArrayEmpty() { return true; } - size_t count = buffer_->allocate_blocks_on_demand_ ? current_blocks_count() - : max_blocks_count(); + size_t count = current_blocks_count(); for (size_t i = 0; i < count; i++) { if (buffer_->blocks_[i] != nullptr) { return false; diff --git a/gquiche/quic/test_tools/quic_test_backend.cc b/gquiche/quic/test_tools/quic_test_backend.cc index 67a44e9a..0cebf0d4 100644 --- a/gquiche/quic/test_tools/quic_test_backend.cc +++ b/gquiche/quic/test_tools/quic_test_backend.cc @@ -7,12 +7,14 @@ #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "gquiche/quic/core/quic_buffer_allocator.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/core/web_transport_interface.h" #include "gquiche/quic/platform/api/quic_mem_slice.h" +#include "gquiche/quic/test_tools/web_transport_resets_backend.h" #include "gquiche/quic/tools/web_transport_test_visitors.h" namespace quic { @@ -20,89 +22,45 @@ namespace test { namespace { -class EchoWebTransportServer : public WebTransportVisitor { +// SessionCloseVisitor implements the "/session-close" endpoint. If the client +// sends a unidirectional stream of format "code message" to this endpoint, it +// will close the session with the corresponding error code and error message. +// For instance, sending "42 test error" will cause it to be closed with code 42 +// and message "test error". +class SessionCloseVisitor : public WebTransportVisitor { public: - EchoWebTransportServer(WebTransportSession* session) : session_(session) {} + SessionCloseVisitor(WebTransportSession* session) : session_(session) {} - void OnSessionReady() override { - if (session_->CanOpenNextOutgoingBidirectionalStream()) { - OnCanCreateNewOutgoingBidirectionalStream(); - } - } - - void OnIncomingBidirectionalStreamAvailable() override { - while (true) { - WebTransportStream* stream = - session_->AcceptIncomingBidirectionalStream(); - if (stream == nullptr) { - return; - } - QUIC_DVLOG(1) << "EchoWebTransportServer received a bidirectional stream " - << stream->GetStreamId(); - stream->SetVisitor( - std::make_unique(stream)); - stream->visitor()->OnCanRead(); - } - } + void OnSessionReady(const spdy::SpdyHeaderBlock& /*headers*/) override {} + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + void OnIncomingBidirectionalStreamAvailable() override {} void OnIncomingUnidirectionalStreamAvailable() override { - while (true) { - WebTransportStream* stream = - session_->AcceptIncomingUnidirectionalStream(); - if (stream == nullptr) { - return; - } - QUIC_DVLOG(1) - << "EchoWebTransportServer received a unidirectional stream"; - stream->SetVisitor( - std::make_unique( - stream, [this](const std::string& data) { - streams_to_echo_back_.push_back(data); - TrySendingUnidirectionalStreams(); - })); - stream->visitor()->OnCanRead(); + WebTransportStream* stream = session_->AcceptIncomingUnidirectionalStream(); + if (stream == nullptr) { + return; } + stream->SetVisitor( + std::make_unique( + stream, [this](const std::string& data) { + std::pair parsed = + absl::StrSplit(data, absl::MaxSplits(' ', 1)); + WebTransportSessionError error_code = 0; + bool success = absl::SimpleAtoi(parsed.first, &error_code); + QUICHE_DCHECK(success) << data; + session_->CloseSession(error_code, parsed.second); + })); + stream->visitor()->OnCanRead(); } - void OnDatagramReceived(absl::string_view datagram) override { - auto buffer = MakeUniqueBuffer(&allocator_, datagram.size()); - memcpy(buffer.get(), datagram.data(), datagram.size()); - QuicMemSlice slice(std::move(buffer), datagram.size()); - session_->SendOrQueueDatagram(std::move(slice)); - } + void OnDatagramReceived(absl::string_view /*datagram*/) override {} - void OnCanCreateNewOutgoingBidirectionalStream() override { - if (!echo_stream_opened_) { - WebTransportStream* stream = session_->OpenOutgoingBidirectionalStream(); - stream->SetVisitor( - std::make_unique(stream)); - echo_stream_opened_ = true; - } - } - void OnCanCreateNewOutgoingUnidirectionalStream() override { - TrySendingUnidirectionalStreams(); - } - - void TrySendingUnidirectionalStreams() { - while (!streams_to_echo_back_.empty() && - session_->CanOpenNextOutgoingUnidirectionalStream()) { - QUIC_DVLOG(1) - << "EchoWebTransportServer echoed a unidirectional stream back"; - WebTransportStream* stream = session_->OpenOutgoingUnidirectionalStream(); - stream->SetVisitor( - std::make_unique( - stream, streams_to_echo_back_.front())); - streams_to_echo_back_.pop_front(); - stream->visitor()->OnCanWrite(); - } - } + void OnCanCreateNewOutgoingBidirectionalStream() override {} + void OnCanCreateNewOutgoingUnidirectionalStream() override {} private: - WebTransportSession* session_; - SimpleBufferAllocator allocator_; - bool echo_stream_opened_ = false; - - QuicCircularDeque streams_to_echo_back_; + WebTransportSession* session_; // Not owned. }; } // namespace @@ -123,10 +81,36 @@ QuicTestBackend::ProcessWebTransportRequest( return response; } absl::string_view path = path_it->second; - if (path == "/echo") { + // Match any "/echo.*" pass, e.g. "/echo_foobar" + if (absl::StartsWith(path, "/echo")) { + WebTransportResponse response; + response.response_headers[":status"] = "200"; + // Add response headers if the paramer has "set-header=XXX:YYY" query. + GURL url = GURL(absl::StrCat("https://localhost", path)); + const std::vector& params = absl::StrSplit(url.query(), '&'); + for (const auto& param : params) { + absl::string_view param_view = param; + if (absl::ConsumePrefix(¶m_view, "set-header=")) { + const std::vector header_value = + absl::StrSplit(param_view, ':'); + if (header_value.size() == 2 && + !absl::StartsWith(header_value[0], ":")) { + response.response_headers[header_value[0]] = header_value[1]; + } + } + } + + response.visitor = + std::make_unique(session); + return response; + } + if (path == "/resets") { + return WebTransportResetsBackend(request_headers, session); + } + if (path == "/session-close") { WebTransportResponse response; response.response_headers[":status"] = "200"; - response.visitor = std::make_unique(session); + response.visitor = std::make_unique(session); return response; } diff --git a/gquiche/quic/test_tools/quic_test_backend.h b/gquiche/quic/test_tools/quic_test_backend.h index 5a96d901..97956381 100644 --- a/gquiche/quic/test_tools/quic_test_backend.h +++ b/gquiche/quic/test_tools/quic_test_backend.h @@ -6,6 +6,7 @@ #define QUICHE_QUIC_TEST_TOOLS_QUIC_TEST_BACKEND_H_ #include "gquiche/quic/tools/quic_memory_cache_backend.h" +#include "gquiche/common/platform/api/quiche_logging.h" namespace quic { namespace test { @@ -21,11 +22,26 @@ class QuicTestBackend : public QuicMemoryCacheBackend { bool SupportsWebTransport() override { return enable_webtransport_; } void set_enable_webtransport(bool enable_webtransport) { + QUICHE_DCHECK(!enable_webtransport || enable_extended_connect_); enable_webtransport_ = enable_webtransport; } + bool UsesDatagramContexts() override { return use_datagram_contexts_; } + + void set_use_datagram_contexts(bool use_datagram_contexts) { + use_datagram_contexts_ = use_datagram_contexts; + } + + bool SupportsExtendedConnect() override { return enable_extended_connect_; } + + void set_enable_extended_connect(bool enable_extended_connect) { + enable_extended_connect_ = enable_extended_connect; + } + private: bool enable_webtransport_ = false; + bool use_datagram_contexts_ = false; + bool enable_extended_connect_ = true; }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_test_client.cc b/gquiche/quic/test_tools/quic_test_client.cc index 08c1119f..99e641ae 100644 --- a/gquiche/quic/test_tools/quic_test_client.cc +++ b/gquiche/quic/test_tools/quic_test_client.cc @@ -28,7 +28,7 @@ #include "gquiche/quic/test_tools/quic_spdy_stream_peer.h" #include "gquiche/quic/test_tools/quic_test_utils.h" #include "gquiche/quic/tools/quic_url.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace test { @@ -102,6 +102,7 @@ class RecordingProofVerifier : public ProofVerifier { return QUIC_FAILURE; } + // Parse the cert into an X509 structure. const uint8_t* data; data = reinterpret_cast(certs[0].data()); bssl::UniquePtr cert(d2i_X509(nullptr, &data, certs[0].size())); @@ -109,15 +110,28 @@ class RecordingProofVerifier : public ProofVerifier { return QUIC_FAILURE; } - static const unsigned kMaxCommonNameLength = 256; - char buf[kMaxCommonNameLength]; - X509_NAME* subject_name = X509_get_subject_name(cert.get()); - if (X509_NAME_get_text_by_NID(subject_name, NID_commonName, buf, - sizeof(buf)) <= 0) { + // Extract the CN field + X509_NAME* subject = X509_get_subject_name(cert.get()); + const int index = X509_NAME_get_index_by_NID(subject, NID_commonName, -1); + if (index < 0) { + return QUIC_FAILURE; + } + ASN1_STRING* name_data = + X509_NAME_ENTRY_get_data(X509_NAME_get_entry(subject, index)); + if (name_data == nullptr) { return QUIC_FAILURE; } - common_name_ = buf; + // Convert the CN to UTF8, in case the cert represents it in a different + // format. + unsigned char* buf = nullptr; + const int len = ASN1_STRING_to_UTF8(&buf, name_data); + if (len <= 0) { + return QUIC_FAILURE; + } + bssl::UniquePtr deleter(buf); + + common_name_.assign(reinterpret_cast(buf), len); cert_sct_ = cert_sct; return QUIC_SUCCESS; } @@ -604,7 +618,7 @@ QuicSpdyClientStream* QuicTestClient::GetOrCreateStream() { return latest_created_stream_; } -QuicErrorCode QuicTestClient::connection_error() { +QuicErrorCode QuicTestClient::connection_error() const { return client()->connection_error(); } @@ -618,16 +632,12 @@ const std::string& QuicTestClient::cert_sct() const { ->cert_sct(); } -QuicTagValueMap QuicTestClient::GetServerConfig() const { +const QuicTagValueMap& QuicTestClient::GetServerConfig() const { QuicCryptoClientConfig* config = client_->crypto_config(); - QuicCryptoClientConfig::CachedState* state = + const QuicCryptoClientConfig::CachedState* state = config->LookupOrCreate(client_->server_id()); const CryptoHandshakeMessage* handshake_msg = state->GetServerConfig(); - if (handshake_msg != nullptr) { - return handshake_msg->tag_value_map(); - } else { - return QuicTagValueMap(); - } + return handshake_msg->tag_value_map(); } bool QuicTestClient::connected() const { @@ -786,7 +796,7 @@ void QuicTestClient::OnClose(QuicSpdyStream* stream) { // written. client()->OnClose(stream); ++num_responses_; - if (!QuicContainsKey(open_streams_, stream->id())) { + if (open_streams_.find(stream->id()) == open_streams_.end()) { return; } if (latest_created_stream_ == stream) { diff --git a/gquiche/quic/test_tools/quic_test_client.h b/gquiche/quic/test_tools/quic_test_client.h index e5744525..7607c288 100644 --- a/gquiche/quic/test_tools/quic_test_client.h +++ b/gquiche/quic/test_tools/quic_test_client.h @@ -14,11 +14,10 @@ #include "gquiche/quic/core/quic_framer.h" #include "gquiche/quic/core/quic_packet_creator.h" #include "gquiche/quic/core/quic_packets.h" -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_epoll.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/tools/quic_client.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -269,7 +268,7 @@ class QuicTestClient : public QuicSpdyStream::Visitor, QuicReferenceCountedPointer ack_listener); QuicRstStreamErrorCode stream_error() { return stream_error_; } - QuicErrorCode connection_error(); + QuicErrorCode connection_error() const; MockableQuicClient* client() { return client_.get(); } const MockableQuicClient* client() const { return client_.get(); } @@ -282,8 +281,8 @@ class QuicTestClient : public QuicSpdyStream::Visitor, // or the empty std::string if no signed timestamp was presented. const std::string& cert_sct() const; - // Get the server config map. - QuicTagValueMap GetServerConfig() const; + // Get the server config map. Server config must exist. + const QuicTagValueMap& GetServerConfig() const; void set_auto_reconnect(bool reconnect) { auto_reconnect_ = reconnect; } @@ -401,7 +400,8 @@ class QuicTestClient : public QuicSpdyStream::Visitor, QuicSpdyClientStream* latest_created_stream_; std::map open_streams_; // Received responses of closed streams. - QuicLinkedHashMap closed_stream_states_; + quiche::QuicheLinkedHashMap + closed_stream_states_; QuicRstStreamErrorCode stream_error_; diff --git a/gquiche/quic/test_tools/quic_test_server.cc b/gquiche/quic/test_tools/quic_test_server.cc index 4d16b931..1de244a3 100644 --- a/gquiche/quic/test_tools/quic_test_server.cc +++ b/gquiche/quic/test_tools/quic_test_server.cc @@ -94,25 +94,25 @@ class QuicTestDispatcher : public QuicSimpleDispatcher { crypto_stream_factory_(nullptr) {} std::unique_ptr CreateQuicSession( - QuicConnectionId id, - const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, + QuicConnectionId id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, const ParsedQuicVersion& version, - absl::string_view sni) override { + const ParsedClientHello& /*parsed_chlo*/) override { QuicReaderMutexLock lock(&factory_lock_); - if (session_factory_ == nullptr && stream_factory_ == nullptr && - crypto_stream_factory_ == nullptr) { - return QuicSimpleDispatcher::CreateQuicSession( - id, self_address, peer_address, alpn, version, sni); - } + // The QuicServerSessionBase takes ownership of |connection| below. QuicConnection* connection = new QuicConnection( id, self_address, peer_address, helper(), alarm_factory(), writer(), /* owns_writer= */ false, Perspective::IS_SERVER, ParsedQuicVersionVector{version}); std::unique_ptr session; - if (stream_factory_ != nullptr || crypto_stream_factory_ != nullptr) { + if (session_factory_ == nullptr && stream_factory_ == nullptr && + crypto_stream_factory_ == nullptr) { + session = std::make_unique( + config(), GetSupportedVersions(), connection, this, session_helper(), + crypto_config(), compressed_certs_cache(), server_backend()); + } else if (stream_factory_ != nullptr || + crypto_stream_factory_ != nullptr) { session = std::make_unique( config(), GetSupportedVersions(), connection, this, session_helper(), crypto_config(), compressed_certs_cache(), stream_factory_, @@ -122,6 +122,14 @@ class QuicTestDispatcher : public QuicSimpleDispatcher { config(), connection, this, session_helper(), crypto_config(), compressed_certs_cache(), server_backend()); } + if (VersionUsesHttp3(version.transport_version) && + GetQuicReloadableFlag(quic_verify_request_headers_2)) { + QUICHE_DCHECK(session->allow_extended_connect()); + // Do not allow extended CONNECT request if the backend doesn't support + // it. + session->set_allow_extended_connect( + server_backend()->SupportsExtendedConnect()); + } session->Initialize(); return session; } diff --git a/gquiche/quic/test_tools/quic_test_utils.cc b/gquiche/quic/test_tools/quic_test_utils.cc index 38afc2e5..ce4ec93d 100644 --- a/gquiche/quic/test_tools/quic_test_utils.cc +++ b/gquiche/quic/test_tools/quic_test_utils.cc @@ -88,9 +88,7 @@ std::vector CreateStatelessResetTokenForTest() { sizeof(kStatelessResetTokenDataForTest)); } -std::string TestHostname() { - return "test.example.org"; -} +std::string TestHostname() { return "test.example.com"; } QuicServerId TestServerId() { return QuicServerId(TestHostname(), kTestPort); @@ -191,7 +189,11 @@ std::unique_ptr BuildUnsizedDataPacket( EncryptionLevel level = HeaderToEncryptionLevel(header); size_t length = framer->BuildDataPacket(header, frames, buffer, packet_size, level); - QUICHE_DCHECK_NE(0u, length); + + if (length == 0) { + delete[] buffer; + return nullptr; + } // Re-construct the data packet with data ownership. return std::make_unique( buffer, length, /* owns_buffer */ true, @@ -1308,25 +1310,13 @@ StreamType DetermineStreamType(QuicStreamId id, : default_type; } -QuicMemSliceSpan MakeSpan(QuicBufferAllocator* allocator, - absl::string_view message_data, - QuicMemSliceStorage* storage) { - if (message_data.length() == 0) { - *storage = - QuicMemSliceStorage(nullptr, 0, allocator, kMaxOutgoingPacketSize); - return storage->ToSpan(); +QuicMemSlice MemSliceFromString(absl::string_view data) { + if (data.empty()) { + return QuicMemSlice(); } - struct iovec iov = {const_cast(message_data.data()), - message_data.length()}; - *storage = QuicMemSliceStorage(&iov, 1, allocator, kMaxOutgoingPacketSize); - return storage->ToSpan(); -} -QuicMemSlice MemSliceFromString(absl::string_view data) { static SimpleBufferAllocator* allocator = new SimpleBufferAllocator(); - QuicUniqueBufferPtr buffer = MakeUniqueBuffer(allocator, data.size()); - memcpy(buffer.get(), data.data(), data.size()); - return QuicMemSlice(std::move(buffer), data.size()); + return QuicMemSlice(QuicBuffer::Copy(allocator, data)); } bool TaggingEncrypter::EncryptPacket(uint64_t /*packet_number*/, @@ -1591,18 +1581,18 @@ bool ParseClientVersionNegotiationProbePacket( QuicEncryptedPacket encrypted_packet(packet_bytes, packet_length); PacketHeaderFormat format; QuicLongHeaderType long_packet_type; - bool version_present, has_length_prefix, retry_token_present; + bool version_present, has_length_prefix; QuicVersionLabel version_label; ParsedQuicVersion parsed_version = ParsedQuicVersion::Unsupported(); QuicConnectionId destination_connection_id, source_connection_id; - absl::string_view retry_token; + absl::optional retry_token; std::string detailed_error; QuicErrorCode error = QuicFramer::ParsePublicHeaderDispatcher( encrypted_packet, /*expected_destination_connection_id_length=*/0, &format, &long_packet_type, &version_present, &has_length_prefix, &version_label, &parsed_version, &destination_connection_id, &source_connection_id, - &retry_token_present, &retry_token, &detailed_error); + &retry_token, &detailed_error); if (error != QUIC_NO_ERROR) { QUIC_BUG(quic_bug_10256_9) << "Failed to parse packet: " << detailed_error; return false; diff --git a/gquiche/quic/test_tools/quic_test_utils.h b/gquiche/quic/test_tools/quic_test_utils.h index 6dae2624..bd1d977f 100644 --- a/gquiche/quic/test_tools/quic_test_utils.h +++ b/gquiche/quic/test_tools/quic_test_utils.h @@ -24,6 +24,7 @@ #include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/quic_connection.h" #include "gquiche/quic/core/quic_connection_id.h" +#include "gquiche/quic/core/quic_error_codes.h" #include "gquiche/quic/core/quic_framer.h" #include "gquiche/quic/core/quic_packet_writer.h" #include "gquiche/quic/core/quic_path_validator.h" @@ -103,26 +104,17 @@ void DisableQuicVersionsWithTls(); // constructed packet, the framer must be set to use NullDecrypter. QuicEncryptedPacket* ConstructEncryptedPacket( QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - bool version_flag, - bool reset_flag, - uint64_t packet_number, - const std::string& data, - bool full_padding, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, bool full_padding, QuicConnectionIdIncluded destination_connection_id_included, QuicConnectionIdIncluded source_connection_id_included, QuicPacketNumberLength packet_number_length, - ParsedQuicVersionVector* versions, - Perspective perspective); + ParsedQuicVersionVector* versions, Perspective perspective); QuicEncryptedPacket* ConstructEncryptedPacket( QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - bool version_flag, - bool reset_flag, - uint64_t packet_number, - const std::string& data, - bool full_padding, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, bool full_padding, QuicConnectionIdIncluded destination_connection_id_included, QuicConnectionIdIncluded source_connection_id_included, QuicPacketNumberLength packet_number_length, @@ -134,11 +126,8 @@ QuicEncryptedPacket* ConstructEncryptedPacket( // constructed packet, the framer must be set to use NullDecrypter. QuicEncryptedPacket* ConstructEncryptedPacket( QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - bool version_flag, - bool reset_flag, - uint64_t packet_number, - const std::string& data, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, QuicConnectionIdIncluded destination_connection_id_included, QuicConnectionIdIncluded source_connection_id_included, QuicPacketNumberLength packet_number_length, @@ -147,11 +136,8 @@ QuicEncryptedPacket* ConstructEncryptedPacket( // This form assumes |versions| == nullptr. QuicEncryptedPacket* ConstructEncryptedPacket( QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - bool version_flag, - bool reset_flag, - uint64_t packet_number, - const std::string& data, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, QuicConnectionIdIncluded destination_connection_id_included, QuicConnectionIdIncluded source_connection_id_included, QuicPacketNumberLength packet_number_length); @@ -161,11 +147,8 @@ QuicEncryptedPacket* ConstructEncryptedPacket( // |versions| == nullptr. QuicEncryptedPacket* ConstructEncryptedPacket( QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - bool version_flag, - bool reset_flag, - uint64_t packet_number, - const std::string& data); + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data); // Creates a client-to-server ZERO-RTT packet that will fail to decrypt. std::unique_ptr GetUndecryptableEarlyPacket( @@ -175,8 +158,7 @@ std::unique_ptr GetUndecryptableEarlyPacket( // Constructs a received packet for testing. The caller must take ownership // of the returned pointer. QuicReceivedPacket* ConstructReceivedPacket( - const QuicEncryptedPacket& encrypted_packet, - QuicTime receipt_time); + const QuicEncryptedPacket& encrypted_packet, QuicTime receipt_time); // Create an encrypted packet for testing whose data portion erroneous. // The specific way the data portion is erroneous is not specified, but @@ -185,15 +167,11 @@ QuicReceivedPacket* ConstructReceivedPacket( // constructed packet, the framer must be set to use NullDecrypter. QuicEncryptedPacket* ConstructMisFramedEncryptedPacket( QuicConnectionId destination_connection_id, - QuicConnectionId source_connection_id, - bool version_flag, - bool reset_flag, - uint64_t packet_number, - const std::string& data, + QuicConnectionId source_connection_id, bool version_flag, bool reset_flag, + uint64_t packet_number, const std::string& data, QuicConnectionIdIncluded destination_connection_id_included, QuicConnectionIdIncluded source_connection_id_included, - QuicPacketNumberLength packet_number_length, - ParsedQuicVersion version, + QuicPacketNumberLength packet_number_length, ParsedQuicVersion version, Perspective perspective); // Returns QuicConfig set to default values. @@ -228,8 +206,7 @@ QuicAckFrame MakeAckFrameWithAckBlocks(size_t num_ack_blocks, // Testing convenice method to construct a QuicAckFrame with |largest_acked|, // ack blocks of width 1 packet and |gap_size|. -QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, - size_t max_num_gaps, +QuicAckFrame MakeAckFrameWithGaps(uint64_t gap_size, size_t max_num_gaps, uint64_t largest_acked); // Returns the encryption level that corresponds to the header type in @@ -241,15 +218,12 @@ EncryptionLevel HeaderToEncryptionLevel(const QuicPacketHeader& header); // is populated with the fields in |header| and |frames|, or is nullptr if the // packet could not be created. std::unique_ptr BuildUnsizedDataPacket( - QuicFramer* framer, - const QuicPacketHeader& header, + QuicFramer* framer, const QuicPacketHeader& header, const QuicFrames& frames); // Returns a QuicPacket that is owned by the caller, and of size |packet_size|. std::unique_ptr BuildUnsizedDataPacket( - QuicFramer* framer, - const QuicPacketHeader& header, - const QuicFrames& frames, - size_t packet_size); + QuicFramer* framer, const QuicPacketHeader& header, + const QuicFrames& frames, size_t packet_size); // Compute SHA-1 hash of the supplied std::string. std::string Sha1Hash(absl::string_view data); @@ -282,7 +256,7 @@ class SimpleRandom : public QuicRandom { private: uint8_t buffer_[4096]; - size_t buffer_offset_; + size_t buffer_offset_ = 0; uint8_t key_[32]; void FillBuffer(); @@ -297,21 +271,14 @@ class MockFramerVisitor : public QuicFramerVisitorInterface { MOCK_METHOD(void, OnError, (QuicFramer*), (override)); // The constructor sets this up to return false by default. - MOCK_METHOD(bool, - OnProtocolVersionMismatch, - (ParsedQuicVersion version), + MOCK_METHOD(bool, OnProtocolVersionMismatch, (ParsedQuicVersion version), (override)); MOCK_METHOD(void, OnPacket, (), (override)); - MOCK_METHOD(void, - OnPublicResetPacket, - (const QuicPublicResetPacket& header), + MOCK_METHOD(void, OnPublicResetPacket, (const QuicPublicResetPacket& header), (override)); - MOCK_METHOD(void, - OnVersionNegotiationPacket, - (const QuicVersionNegotiationPacket& packet), - (override)); - MOCK_METHOD(void, - OnRetryPacket, + MOCK_METHOD(void, OnVersionNegotiationPacket, + (const QuicVersionNegotiationPacket& packet), (override)); + MOCK_METHOD(void, OnRetryPacket, (QuicConnectionId original_connection_id, QuicConnectionId new_connection_id, absl::string_view retry_token, @@ -319,133 +286,75 @@ class MockFramerVisitor : public QuicFramerVisitorInterface { absl::string_view retry_without_tag), (override)); // The constructor sets this up to return true by default. - MOCK_METHOD(bool, - OnUnauthenticatedHeader, - (const QuicPacketHeader& header), + MOCK_METHOD(bool, OnUnauthenticatedHeader, (const QuicPacketHeader& header), (override)); // The constructor sets this up to return true by default. - MOCK_METHOD(bool, - OnUnauthenticatedPublicHeader, - (const QuicPacketHeader& header), - (override)); - MOCK_METHOD(void, - OnDecryptedPacket, - (size_t length, EncryptionLevel level), + MOCK_METHOD(bool, OnUnauthenticatedPublicHeader, + (const QuicPacketHeader& header), (override)); + MOCK_METHOD(void, OnDecryptedPacket, (size_t length, EncryptionLevel level), (override)); - MOCK_METHOD(bool, - OnPacketHeader, - (const QuicPacketHeader& header), + MOCK_METHOD(bool, OnPacketHeader, (const QuicPacketHeader& header), (override)); - MOCK_METHOD(void, - OnCoalescedPacket, - (const QuicEncryptedPacket& packet), + MOCK_METHOD(void, OnCoalescedPacket, (const QuicEncryptedPacket& packet), (override)); - MOCK_METHOD(void, - OnUndecryptablePacket, + MOCK_METHOD(void, OnUndecryptablePacket, (const QuicEncryptedPacket& packet, - EncryptionLevel decryption_level, - bool has_decryption_key), + EncryptionLevel decryption_level, bool has_decryption_key), (override)); MOCK_METHOD(bool, OnStreamFrame, (const QuicStreamFrame& frame), (override)); MOCK_METHOD(bool, OnCryptoFrame, (const QuicCryptoFrame& frame), (override)); - MOCK_METHOD(bool, - OnAckFrameStart, - (QuicPacketNumber, QuicTime::Delta), + MOCK_METHOD(bool, OnAckFrameStart, (QuicPacketNumber, QuicTime::Delta), (override)); - MOCK_METHOD(bool, - OnAckRange, - (QuicPacketNumber, QuicPacketNumber), + MOCK_METHOD(bool, OnAckRange, (QuicPacketNumber, QuicPacketNumber), (override)); MOCK_METHOD(bool, OnAckTimestamp, (QuicPacketNumber, QuicTime), (override)); MOCK_METHOD(bool, OnAckFrameEnd, (QuicPacketNumber), (override)); - MOCK_METHOD(bool, - OnStopWaitingFrame, - (const QuicStopWaitingFrame& frame), + MOCK_METHOD(bool, OnStopWaitingFrame, (const QuicStopWaitingFrame& frame), (override)); - MOCK_METHOD(bool, - OnPaddingFrame, - (const QuicPaddingFrame& frame), + MOCK_METHOD(bool, OnPaddingFrame, (const QuicPaddingFrame& frame), (override)); MOCK_METHOD(bool, OnPingFrame, (const QuicPingFrame& frame), (override)); - MOCK_METHOD(bool, - OnRstStreamFrame, - (const QuicRstStreamFrame& frame), - (override)); - MOCK_METHOD(bool, - OnConnectionCloseFrame, - (const QuicConnectionCloseFrame& frame), - (override)); - MOCK_METHOD(bool, - OnNewConnectionIdFrame, - (const QuicNewConnectionIdFrame& frame), - (override)); - MOCK_METHOD(bool, - OnRetireConnectionIdFrame, - (const QuicRetireConnectionIdFrame& frame), + MOCK_METHOD(bool, OnRstStreamFrame, (const QuicRstStreamFrame& frame), (override)); - MOCK_METHOD(bool, - OnNewTokenFrame, - (const QuicNewTokenFrame& frame), + MOCK_METHOD(bool, OnConnectionCloseFrame, + (const QuicConnectionCloseFrame& frame), (override)); + MOCK_METHOD(bool, OnNewConnectionIdFrame, + (const QuicNewConnectionIdFrame& frame), (override)); + MOCK_METHOD(bool, OnRetireConnectionIdFrame, + (const QuicRetireConnectionIdFrame& frame), (override)); + MOCK_METHOD(bool, OnNewTokenFrame, (const QuicNewTokenFrame& frame), (override)); - MOCK_METHOD(bool, - OnStopSendingFrame, - (const QuicStopSendingFrame& frame), + MOCK_METHOD(bool, OnStopSendingFrame, (const QuicStopSendingFrame& frame), (override)); - MOCK_METHOD(bool, - OnPathChallengeFrame, - (const QuicPathChallengeFrame& frame), + MOCK_METHOD(bool, OnPathChallengeFrame, (const QuicPathChallengeFrame& frame), (override)); - MOCK_METHOD(bool, - OnPathResponseFrame, - (const QuicPathResponseFrame& frame), + MOCK_METHOD(bool, OnPathResponseFrame, (const QuicPathResponseFrame& frame), (override)); MOCK_METHOD(bool, OnGoAwayFrame, (const QuicGoAwayFrame& frame), (override)); - MOCK_METHOD(bool, - OnMaxStreamsFrame, - (const QuicMaxStreamsFrame& frame), + MOCK_METHOD(bool, OnMaxStreamsFrame, (const QuicMaxStreamsFrame& frame), (override)); - MOCK_METHOD(bool, - OnStreamsBlockedFrame, - (const QuicStreamsBlockedFrame& frame), + MOCK_METHOD(bool, OnStreamsBlockedFrame, + (const QuicStreamsBlockedFrame& frame), (override)); + MOCK_METHOD(bool, OnWindowUpdateFrame, (const QuicWindowUpdateFrame& frame), (override)); - MOCK_METHOD(bool, - OnWindowUpdateFrame, - (const QuicWindowUpdateFrame& frame), + MOCK_METHOD(bool, OnBlockedFrame, (const QuicBlockedFrame& frame), (override)); - MOCK_METHOD(bool, - OnBlockedFrame, - (const QuicBlockedFrame& frame), + MOCK_METHOD(bool, OnMessageFrame, (const QuicMessageFrame& frame), (override)); - MOCK_METHOD(bool, - OnMessageFrame, - (const QuicMessageFrame& frame), + MOCK_METHOD(bool, OnHandshakeDoneFrame, (const QuicHandshakeDoneFrame& frame), (override)); - MOCK_METHOD(bool, - OnHandshakeDoneFrame, - (const QuicHandshakeDoneFrame& frame), - (override)); - MOCK_METHOD(bool, - OnAckFrequencyFrame, - (const QuicAckFrequencyFrame& frame), + MOCK_METHOD(bool, OnAckFrequencyFrame, (const QuicAckFrequencyFrame& frame), (override)); MOCK_METHOD(void, OnPacketComplete, (), (override)); - MOCK_METHOD(bool, - IsValidStatelessResetToken, - (const StatelessResetToken&), + MOCK_METHOD(bool, IsValidStatelessResetToken, (const StatelessResetToken&), (const, override)); - MOCK_METHOD(void, - OnAuthenticatedIetfStatelessResetPacket, - (const QuicIetfStatelessResetPacket&), - (override)); + MOCK_METHOD(void, OnAuthenticatedIetfStatelessResetPacket, + (const QuicIetfStatelessResetPacket&), (override)); MOCK_METHOD(void, OnKeyUpdate, (KeyUpdateReason), (override)); MOCK_METHOD(void, OnDecryptedFirstPacketInKeyPhase, (), (override)); MOCK_METHOD(std::unique_ptr, - AdvanceKeysAndCreateCurrentOneRttDecrypter, - (), - (override)); - MOCK_METHOD(std::unique_ptr, - CreateCurrentOneRttEncrypter, - (), + AdvanceKeysAndCreateCurrentOneRttDecrypter, (), (override)); + MOCK_METHOD(std::unique_ptr, CreateCurrentOneRttEncrypter, (), (override)); }; @@ -529,101 +438,70 @@ class MockQuicConnectionVisitor : public QuicConnectionVisitorInterface { MOCK_METHOD(void, OnStreamFrame, (const QuicStreamFrame& frame), (override)); MOCK_METHOD(void, OnCryptoFrame, (const QuicCryptoFrame& frame), (override)); - MOCK_METHOD(void, - OnWindowUpdateFrame, - (const QuicWindowUpdateFrame& frame), + MOCK_METHOD(void, OnWindowUpdateFrame, (const QuicWindowUpdateFrame& frame), (override)); - MOCK_METHOD(void, - OnBlockedFrame, - (const QuicBlockedFrame& frame), + MOCK_METHOD(void, OnBlockedFrame, (const QuicBlockedFrame& frame), (override)); MOCK_METHOD(void, OnRstStream, (const QuicRstStreamFrame& frame), (override)); MOCK_METHOD(void, OnGoAway, (const QuicGoAwayFrame& frame), (override)); MOCK_METHOD(void, OnMessageReceived, (absl::string_view message), (override)); MOCK_METHOD(void, OnHandshakeDoneReceived, (), (override)); MOCK_METHOD(void, OnNewTokenReceived, (absl::string_view token), (override)); - MOCK_METHOD(void, - OnConnectionClosed, + MOCK_METHOD(void, OnConnectionClosed, (const QuicConnectionCloseFrame& frame, ConnectionCloseSource source), (override)); MOCK_METHOD(void, OnWriteBlocked, (), (override)); MOCK_METHOD(void, OnCanWrite, (), (override)); MOCK_METHOD(bool, SendProbingData, (), (override)); - MOCK_METHOD(bool, - ValidateStatelessReset, + MOCK_METHOD(bool, ValidateStatelessReset, (const quic::QuicSocketAddress&, const quic::QuicSocketAddress&), (override)); MOCK_METHOD(void, OnCongestionWindowChange, (QuicTime now), (override)); - MOCK_METHOD(void, - OnConnectionMigration, - (AddressChangeType type), + MOCK_METHOD(void, OnConnectionMigration, (AddressChangeType type), (override)); MOCK_METHOD(void, OnPathDegrading, (), (override)); MOCK_METHOD(void, OnForwardProgressMadeAfterPathDegrading, (), (override)); MOCK_METHOD(bool, WillingAndAbleToWrite, (), (const, override)); MOCK_METHOD(bool, ShouldKeepConnectionAlive, (), (const, override)); MOCK_METHOD(std::string, GetStreamsInfoForLogging, (), (const, override)); - MOCK_METHOD(void, - OnSuccessfulVersionNegotiation, - (const ParsedQuicVersion& version), - (override)); - MOCK_METHOD(void, - OnPacketReceived, + MOCK_METHOD(void, OnSuccessfulVersionNegotiation, + (const ParsedQuicVersion& version), (override)); + MOCK_METHOD(void, OnPacketReceived, (const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, bool is_connectivity_probe), (override)); MOCK_METHOD(void, OnAckNeedsRetransmittableFrame, (), (override)); - MOCK_METHOD(void, - SendAckFrequency, - (const QuicAckFrequencyFrame& frame), - (override)); - MOCK_METHOD(void, - SendNewConnectionId, - (const QuicNewConnectionIdFrame& frame), - (override)); - MOCK_METHOD(void, - SendRetireConnectionId, - (uint64_t sequence_number), + MOCK_METHOD(void, SendAckFrequency, (const QuicAckFrequencyFrame& frame), (override)); - MOCK_METHOD(void, - OnServerConnectionIdIssued, - (const QuicConnectionId& server_connection_id), - (override)); - MOCK_METHOD(void, - OnServerConnectionIdRetired, - (const QuicConnectionId& server_connection_id), + MOCK_METHOD(void, SendNewConnectionId, + (const QuicNewConnectionIdFrame& frame), (override)); + MOCK_METHOD(void, SendRetireConnectionId, (uint64_t sequence_number), (override)); + MOCK_METHOD(void, OnServerConnectionIdIssued, + (const QuicConnectionId& server_connection_id), (override)); + MOCK_METHOD(void, OnServerConnectionIdRetired, + (const QuicConnectionId& server_connection_id), (override)); MOCK_METHOD(bool, AllowSelfAddressChange, (), (const, override)); MOCK_METHOD(HandshakeState, GetHandshakeState, (), (const, override)); - MOCK_METHOD(bool, - OnMaxStreamsFrame, - (const QuicMaxStreamsFrame& frame), - (override)); - MOCK_METHOD(bool, - OnStreamsBlockedFrame, - (const QuicStreamsBlockedFrame& frame), + MOCK_METHOD(bool, OnMaxStreamsFrame, (const QuicMaxStreamsFrame& frame), (override)); - MOCK_METHOD(void, - OnStopSendingFrame, - (const QuicStopSendingFrame& frame), + MOCK_METHOD(bool, OnStreamsBlockedFrame, + (const QuicStreamsBlockedFrame& frame), (override)); + MOCK_METHOD(void, OnStopSendingFrame, (const QuicStopSendingFrame& frame), (override)); MOCK_METHOD(void, OnPacketDecrypted, (EncryptionLevel), (override)); MOCK_METHOD(void, OnOneRttPacketAcknowledged, (), (override)); MOCK_METHOD(void, OnHandshakePacketSent, (), (override)); MOCK_METHOD(void, OnKeyUpdate, (KeyUpdateReason), (override)); MOCK_METHOD(std::unique_ptr, - AdvanceKeysAndCreateCurrentOneRttDecrypter, - (), - (override)); - MOCK_METHOD(std::unique_ptr, - CreateCurrentOneRttEncrypter, - (), + AdvanceKeysAndCreateCurrentOneRttDecrypter, (), (override)); + MOCK_METHOD(std::unique_ptr, CreateCurrentOneRttEncrypter, (), (override)); MOCK_METHOD(void, BeforeConnectionCloseSent, (), (override)); - MOCK_METHOD(bool, ValidateToken, (absl::string_view), (const, override)); - MOCK_METHOD(void, MaybeSendAddressToken, (), (override)); + MOCK_METHOD(bool, ValidateToken, (absl::string_view), (override)); + MOCK_METHOD(bool, MaybeSendAddressToken, (), (override)); bool IsKnownServerAddress( const QuicSocketAddress& /*address*/) const override { @@ -672,36 +550,57 @@ class MockAlarmFactory : public QuicAlarmFactory { } }; +class TestAlarmFactory : public QuicAlarmFactory { + public: + class TestAlarm : public QuicAlarm { + public: + explicit TestAlarm(QuicArenaScopedPtr delegate) + : QuicAlarm(std::move(delegate)) {} + + void SetImpl() override {} + void CancelImpl() override {} + using QuicAlarm::Fire; + }; + + TestAlarmFactory() {} + TestAlarmFactory(const TestAlarmFactory&) = delete; + TestAlarmFactory& operator=(const TestAlarmFactory&) = delete; + + QuicAlarm* CreateAlarm(QuicAlarm::Delegate* delegate) override { + return new TestAlarm(QuicArenaScopedPtr(delegate)); + } + + QuicArenaScopedPtr CreateAlarm( + QuicArenaScopedPtr delegate, + QuicConnectionArena* arena) override { + return arena->New(std::move(delegate)); + } +}; + class MockQuicConnection : public QuicConnection { public: // Uses a ConnectionId of 42 and 127.0.0.1:123. MockQuicConnection(MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, - Perspective perspective); + MockAlarmFactory* alarm_factory, Perspective perspective); // Uses a ConnectionId of 42. MockQuicConnection(QuicSocketAddress address, MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, - Perspective perspective); + MockAlarmFactory* alarm_factory, Perspective perspective); // Uses 127.0.0.1:123. MockQuicConnection(QuicConnectionId connection_id, MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, - Perspective perspective); + MockAlarmFactory* alarm_factory, Perspective perspective); // Uses a ConnectionId of 42, and 127.0.0.1:123. MockQuicConnection(MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, - Perspective perspective, + MockAlarmFactory* alarm_factory, Perspective perspective, const ParsedQuicVersionVector& supported_versions); - MockQuicConnection(QuicConnectionId connection_id, - QuicSocketAddress address, + MockQuicConnection(QuicConnectionId connection_id, QuicSocketAddress address, MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, - Perspective perspective, + MockAlarmFactory* alarm_factory, Perspective perspective, const ParsedQuicVersionVector& supported_versions); MockQuicConnection(const MockQuicConnection&) = delete; MockQuicConnection& operator=(const MockQuicConnection&) = delete; @@ -713,66 +612,45 @@ class MockQuicConnection : public QuicConnection { // will advance the time of the MockClock. void AdvanceTime(QuicTime::Delta delta); - MOCK_METHOD(void, - ProcessUdpPacket, + MOCK_METHOD(void, ProcessUdpPacket, (const QuicSocketAddress& self_address, const QuicSocketAddress& peer_address, const QuicReceivedPacket& packet), (override)); - MOCK_METHOD(void, - CloseConnection, - (QuicErrorCode error, - const std::string& details, + MOCK_METHOD(void, CloseConnection, + (QuicErrorCode error, const std::string& details, ConnectionCloseBehavior connection_close_behavior), (override)); - MOCK_METHOD(void, - CloseConnection, - (QuicErrorCode error, - QuicIetfTransportErrorCodes ietf_error, + MOCK_METHOD(void, CloseConnection, + (QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, const std::string& details, ConnectionCloseBehavior connection_close_behavior), (override)); - MOCK_METHOD(void, - SendConnectionClosePacket, - (QuicErrorCode error, - QuicIetfTransportErrorCodes ietf_error, + MOCK_METHOD(void, SendConnectionClosePacket, + (QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, const std::string& details), (override)); MOCK_METHOD(void, OnCanWrite, (), (override)); - MOCK_METHOD(void, - SendConnectivityProbingResponsePacket, - (const QuicSocketAddress& peer_address), - (override)); - MOCK_METHOD(bool, - SendConnectivityProbingPacket, + MOCK_METHOD(void, SendConnectivityProbingResponsePacket, + (const QuicSocketAddress& peer_address), (override)); + MOCK_METHOD(bool, SendConnectivityProbingPacket, (QuicPacketWriter*, const QuicSocketAddress& peer_address), (override)); - MOCK_METHOD(void, - OnSendConnectionState, - (const CachedNetworkParameters&), - (override)); - MOCK_METHOD(void, - ResumeConnectionState, - (const CachedNetworkParameters&, bool), + MOCK_METHOD(void, OnSendConnectionState, (const CachedNetworkParameters&), (override)); + MOCK_METHOD(void, ResumeConnectionState, + (const CachedNetworkParameters&, bool), (override)); MOCK_METHOD(void, SetMaxPacingRate, (QuicBandwidth), (override)); - MOCK_METHOD(void, - OnStreamReset, - (QuicStreamId, QuicRstStreamErrorCode), + MOCK_METHOD(void, OnStreamReset, (QuicStreamId, QuicRstStreamErrorCode), (override)); MOCK_METHOD(bool, SendControlFrame, (const QuicFrame& frame), (override)); - MOCK_METHOD(MessageStatus, - SendMessage, - (QuicMessageId, QuicMemSliceSpan, bool), - (override)); - MOCK_METHOD(bool, - SendPathChallenge, - (const QuicPathFrameBuffer&, - const QuicSocketAddress&, - const QuicSocketAddress&, - const QuicSocketAddress&, + MOCK_METHOD(MessageStatus, SendMessage, + (QuicMessageId, absl::Span, bool), (override)); + MOCK_METHOD(bool, SendPathChallenge, + (const QuicPathFrameBuffer&, const QuicSocketAddress&, + const QuicSocketAddress&, const QuicSocketAddress&, QuicPacketWriter*), (override)); @@ -784,8 +662,7 @@ class MockQuicConnection : public QuicConnection { void ReallyOnCanWrite() { QuicConnection::OnCanWrite(); } void ReallyCloseConnection( - QuicErrorCode error, - const std::string& details, + QuicErrorCode error, const std::string& details, ConnectionCloseBehavior connection_close_behavior) { // Call the 4-param method directly instead of the 3-param method, so that // it doesn't invoke the virtual 4-param method causing the mock 4-param @@ -795,8 +672,7 @@ class MockQuicConnection : public QuicConnection { } void ReallyCloseConnection4( - QuicErrorCode error, - QuicIetfTransportErrorCodes ietf_error, + QuicErrorCode error, QuicIetfTransportErrorCodes ietf_error, const std::string& details, ConnectionCloseBehavior connection_close_behavior) { QuicConnection::CloseConnection(error, ietf_error, details, @@ -822,8 +698,7 @@ class MockQuicConnection : public QuicConnection { } bool ReallySendConnectivityProbingPacket( - QuicPacketWriter* probing_writer, - const QuicSocketAddress& peer_address) { + QuicPacketWriter* probing_writer, const QuicSocketAddress& peer_address) { return QuicConnection::SendConnectivityProbingPacket(probing_writer, peer_address); } @@ -837,18 +712,12 @@ class MockQuicConnection : public QuicConnection { return QuicConnection::OnPathResponseFrame(frame); } - MOCK_METHOD(bool, - OnPathResponseFrame, - (const QuicPathResponseFrame&), - (override)); - MOCK_METHOD(bool, - OnStopSendingFrame, - (const QuicStopSendingFrame& frame), + MOCK_METHOD(bool, OnPathResponseFrame, (const QuicPathResponseFrame&), (override)); - MOCK_METHOD(size_t, - SendCryptoData, - (EncryptionLevel, size_t, QuicStreamOffset), + MOCK_METHOD(bool, OnStopSendingFrame, (const QuicStopSendingFrame& frame), (override)); + MOCK_METHOD(size_t, SendCryptoData, + (EncryptionLevel, size_t, QuicStreamOffset), (override)); size_t QuicConnection_SendCryptoData(EncryptionLevel level, size_t write_length, QuicStreamOffset offset) { @@ -894,63 +763,47 @@ class MockQuicSession : public QuicSession { const QuicCryptoStream* GetCryptoStream() const override; void SetCryptoStream(QuicCryptoStream* crypto_stream); - MOCK_METHOD(void, - OnConnectionClosed, + MOCK_METHOD(void, OnConnectionClosed, (const QuicConnectionCloseFrame& frame, ConnectionCloseSource source), (override)); MOCK_METHOD(QuicStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (PendingStream*), - (override)); - MOCK_METHOD(QuicConsumedData, - WritevData, - (QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, - absl::optional level), - (override)); - MOCK_METHOD(bool, - WriteControlFrame, - (const QuicFrame& frame, TransmissionType type), - (override)); - MOCK_METHOD(void, - MaybeSendRstStreamFrame, - (QuicStreamId stream_id, - QuicRstStreamErrorCode error, - QuicStreamOffset bytes_written), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), + (override)); + MOCK_METHOD(QuicConsumedData, WritevData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + EncryptionLevel level), (override)); - MOCK_METHOD(void, - MaybeSendStopSendingFrame, - (QuicStreamId stream_id, QuicRstStreamErrorCode error), + MOCK_METHOD(bool, WriteControlFrame, + (const QuicFrame& frame, TransmissionType type), (override)); + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written), (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); MOCK_METHOD(bool, ShouldKeepConnectionAlive, (), (const, override)); MOCK_METHOD(std::vector, GetAlpnsToOffer, (), (const, override)); - MOCK_METHOD(std::vector::const_iterator, - SelectAlpn, - (const std::vector&), - (const, override)); + MOCK_METHOD(std::vector::const_iterator, SelectAlpn, + (const std::vector&), (const, override)); MOCK_METHOD(void, OnAlpnSelected, (absl::string_view), (override)); using QuicSession::ActivateStream; // Returns a QuicConsumedData that indicates all of |write_length| (and |fin| // if set) has been consumed. - QuicConsumedData ConsumeData(QuicStreamId id, - size_t write_length, + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, + StreamSendingState state, TransmissionType type, absl::optional level); void ReallyMaybeSendRstStreamFrame(QuicStreamId id, QuicRstStreamErrorCode error, QuicStreamOffset bytes_written) { - QuicSession::MaybeSendRstStreamFrame(id, error, bytes_written); + QuicSession::MaybeSendRstStreamFrame( + id, QuicResetStreamError::FromInternal(error), bytes_written); } private: @@ -974,10 +827,19 @@ class MockQuicCryptoStream : public QuicCryptoStream { void OnHandshakePacketSent() override {} void OnHandshakeDoneReceived() override {} void OnNewTokenReceived(absl::string_view /*token*/) override {} - std::string GetAddressToken() const override { return ""; } + std::string GetAddressToken( + const CachedNetworkParameters* /*cached_network_parameters*/) + const override { + return ""; + } bool ValidateAddressToken(absl::string_view /*token*/) const override { return true; } + const CachedNetworkParameters* PreviousCachedNetworkParams() const override { + return nullptr; + } + void SetPreviousCachedNetworkParams( + CachedNetworkParameters /*cached_network_params*/) override {} void OnConnectionClosed(QuicErrorCode /*error*/, ConnectionCloseSource /*source*/) override {} HandshakeState GetHandshakeState() const override { return HANDSHAKE_START; } @@ -991,6 +853,13 @@ class MockQuicCryptoStream : public QuicCryptoStream { std::unique_ptr CreateCurrentOneRttEncrypter() override { return nullptr; } + bool ExportKeyingMaterial(absl::string_view /*label*/, + absl::string_view /*context*/, + size_t /*result_len*/, + std::string* /*result*/) override { + return false; + } + SSL* GetSsl() const override { return nullptr; } private: QuicReferenceCountedPointer params_; @@ -1018,86 +887,57 @@ class MockQuicSpdySession : public QuicSpdySession { } // From QuicSession. - MOCK_METHOD(void, - OnConnectionClosed, + MOCK_METHOD(void, OnConnectionClosed, (const QuicConnectionCloseFrame& frame, ConnectionCloseSource source), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (QuicStreamId id), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (PendingStream*), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingBidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingUnidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), (override)); MOCK_METHOD(bool, ShouldCreateIncomingStream, (QuicStreamId id), (override)); MOCK_METHOD(bool, ShouldCreateOutgoingBidirectionalStream, (), (override)); MOCK_METHOD(bool, ShouldCreateOutgoingUnidirectionalStream, (), (override)); - MOCK_METHOD(QuicConsumedData, - WritevData, - (QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, - absl::optional level), - (override)); - MOCK_METHOD(void, - MaybeSendRstStreamFrame, - (QuicStreamId stream_id, - QuicRstStreamErrorCode error, - QuicStreamOffset bytes_written), - (override)); - MOCK_METHOD(void, - MaybeSendStopSendingFrame, - (QuicStreamId stream_id, QuicRstStreamErrorCode error), + MOCK_METHOD(QuicConsumedData, WritevData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + EncryptionLevel level), (override)); - MOCK_METHOD(void, - SendWindowUpdate, - (QuicStreamId id, QuicStreamOffset byte_offset), + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, + QuicStreamOffset bytes_written), (override)); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); + MOCK_METHOD(void, SendWindowUpdate, + (QuicStreamId id, QuicStreamOffset byte_offset), (override)); MOCK_METHOD(void, SendBlocked, (QuicStreamId id), (override)); - MOCK_METHOD(void, - OnStreamHeadersPriority, + MOCK_METHOD(void, OnStreamHeadersPriority, (QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence), (override)); - MOCK_METHOD(void, - OnStreamHeaderList, - (QuicStreamId stream_id, - bool fin, - size_t frame_len, + MOCK_METHOD(void, OnStreamHeaderList, + (QuicStreamId stream_id, bool fin, size_t frame_len, const QuicHeaderList& header_list), (override)); - MOCK_METHOD(void, - OnPromiseHeaderList, - (QuicStreamId stream_id, - QuicStreamId promised_stream_id, - size_t frame_len, - const QuicHeaderList& header_list), + MOCK_METHOD(void, OnPromiseHeaderList, + (QuicStreamId stream_id, QuicStreamId promised_stream_id, + size_t frame_len, const QuicHeaderList& header_list), (override)); - MOCK_METHOD(void, - OnPriorityFrame, + MOCK_METHOD(void, OnPriorityFrame, (QuicStreamId id, const spdy::SpdyStreamPrecedence& precedence), (override)); MOCK_METHOD(void, OnCongestionWindowChange, (QuicTime now), (override)); // Returns a QuicConsumedData that indicates all of |write_length| (and |fin| // if set) has been consumed. - QuicConsumedData ConsumeData(QuicStreamId id, - size_t write_length, + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, + StreamSendingState state, TransmissionType type, absl::optional level); using QuicSession::ActivateStream; @@ -1112,89 +952,45 @@ class MockHttp3DebugVisitor : public Http3DebugVisitor { MOCK_METHOD(void, OnQpackEncoderStreamCreated, (QuicStreamId), (override)); MOCK_METHOD(void, OnQpackDecoderStreamCreated, (QuicStreamId), (override)); MOCK_METHOD(void, OnPeerControlStreamCreated, (QuicStreamId), (override)); - MOCK_METHOD(void, - OnPeerQpackEncoderStreamCreated, - (QuicStreamId), + MOCK_METHOD(void, OnPeerQpackEncoderStreamCreated, (QuicStreamId), (override)); - MOCK_METHOD(void, - OnPeerQpackDecoderStreamCreated, - (QuicStreamId), + MOCK_METHOD(void, OnPeerQpackDecoderStreamCreated, (QuicStreamId), (override)); - MOCK_METHOD(void, - OnSettingsFrameReceivedViaAlps, - (const SettingsFrame&), + MOCK_METHOD(void, OnSettingsFrameReceivedViaAlps, (const SettingsFrame&), (override)); - MOCK_METHOD(void, - OnAcceptChFrameReceivedViaAlps, - (const AcceptChFrame&), + MOCK_METHOD(void, OnAcceptChFrameReceivedViaAlps, (const AcceptChFrame&), (override)); - MOCK_METHOD(void, - OnCancelPushFrameReceived, - (const CancelPushFrame&), - (override)); - MOCK_METHOD(void, - OnSettingsFrameReceived, - (const SettingsFrame&), + MOCK_METHOD(void, OnSettingsFrameReceived, (const SettingsFrame&), (override)); MOCK_METHOD(void, OnGoAwayFrameReceived, (const GoAwayFrame&), (override)); - MOCK_METHOD(void, - OnMaxPushIdFrameReceived, - (const MaxPushIdFrame&), + MOCK_METHOD(void, OnMaxPushIdFrameReceived, (const MaxPushIdFrame&), (override)); - MOCK_METHOD(void, - OnPriorityUpdateFrameReceived, - (const PriorityUpdateFrame&), + MOCK_METHOD(void, OnPriorityUpdateFrameReceived, (const PriorityUpdateFrame&), (override)); - MOCK_METHOD(void, - OnAcceptChFrameReceived, - (const AcceptChFrame&), + MOCK_METHOD(void, OnAcceptChFrameReceived, (const AcceptChFrame&), (override)); - MOCK_METHOD(void, - OnDataFrameReceived, - (QuicStreamId, QuicByteCount), - (override)); - MOCK_METHOD(void, - OnHeadersFrameReceived, - (QuicStreamId, QuicByteCount), - (override)); - MOCK_METHOD(void, - OnHeadersDecoded, - (QuicStreamId, QuicHeaderList), + MOCK_METHOD(void, OnDataFrameReceived, (QuicStreamId, QuicByteCount), (override)); - MOCK_METHOD(void, - OnPushPromiseFrameReceived, - (QuicStreamId, QuicStreamId, QuicByteCount), + MOCK_METHOD(void, OnHeadersFrameReceived, (QuicStreamId, QuicByteCount), (override)); - MOCK_METHOD(void, - OnPushPromiseDecoded, - (QuicStreamId, QuicStreamId, QuicHeaderList), - (override)); - MOCK_METHOD(void, - OnUnknownFrameReceived, - (QuicStreamId, uint64_t, QuicByteCount), + MOCK_METHOD(void, OnHeadersDecoded, (QuicStreamId, QuicHeaderList), (override)); + MOCK_METHOD(void, OnUnknownFrameReceived, + (QuicStreamId, uint64_t, QuicByteCount), (override)); MOCK_METHOD(void, OnSettingsFrameSent, (const SettingsFrame&), (override)); MOCK_METHOD(void, OnGoAwayFrameSent, (QuicStreamId), (override)); MOCK_METHOD(void, OnMaxPushIdFrameSent, (const MaxPushIdFrame&), (override)); - MOCK_METHOD(void, - OnPriorityUpdateFrameSent, - (const PriorityUpdateFrame&), + MOCK_METHOD(void, OnPriorityUpdateFrameSent, (const PriorityUpdateFrame&), (override)); MOCK_METHOD(void, OnDataFrameSent, (QuicStreamId, QuicByteCount), (override)); - MOCK_METHOD(void, - OnHeadersFrameSent, - (QuicStreamId, const spdy::SpdyHeaderBlock&), - (override)); - MOCK_METHOD(void, - OnPushPromiseFrameSent, - (QuicStreamId, QuicStreamId, const spdy::SpdyHeaderBlock&), - (override)); + MOCK_METHOD(void, OnHeadersFrameSent, + (QuicStreamId, const spdy::SpdyHeaderBlock&), (override)); }; class TestQuicSpdyServerSession : public QuicServerSessionBase { @@ -1210,26 +1006,16 @@ class TestQuicSpdyServerSession : public QuicServerSessionBase { delete; ~TestQuicSpdyServerSession() override; - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (QuicStreamId id), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (PendingStream*), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingBidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingUnidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), (override)); - MOCK_METHOD(std::vector::const_iterator, - SelectAlpn, - (const std::vector&), - (const, override)); + MOCK_METHOD(std::vector::const_iterator, SelectAlpn, + (const std::vector&), (const, override)); MOCK_METHOD(void, OnAlpnSelected, (absl::string_view), (override)); std::unique_ptr CreateQuicCryptoServerStream( const QuicCryptoServerConfig* crypto_config, @@ -1241,9 +1027,35 @@ class TestQuicSpdyServerSession : public QuicServerSessionBase { MockQuicCryptoServerStreamHelper* helper() { return &helper_; } + QuicSSLConfig GetSSLConfig() const override { + QuicSSLConfig ssl_config = QuicServerSessionBase::GetSSLConfig(); + if (early_data_enabled_.has_value()) { + ssl_config.early_data_enabled = *early_data_enabled_; + } + if (client_cert_mode_.has_value()) { + ssl_config.client_cert_mode = *client_cert_mode_; + } + + return ssl_config; + } + + void set_early_data_enabled(bool enabled) { early_data_enabled_ = enabled; } + + void set_client_cert_mode(ClientCertMode mode) { + if (support_client_cert()) { + client_cert_mode_ = mode; + } + } + private: MockQuicSessionVisitor visitor_; MockQuicCryptoServerStreamHelper helper_; + // If not nullopt, override the early_data_enabled value from base class' + // ssl_config. + absl::optional early_data_enabled_; + // If not nullopt, override the client_cert_mode value from base class' + // ssl_config. + absl::optional client_cert_mode_; }; // A test implementation of QuicClientPushPromiseIndex::Delegate. @@ -1283,31 +1095,19 @@ class TestQuicSpdyClientSession : public QuicSpdyClientSessionBase { bool IsAuthorized(const std::string& authority) override; // QuicSpdyClientSessionBase - MOCK_METHOD(void, - OnProofValid, - (const QuicCryptoClientConfig::CachedState& cached), - (override)); - MOCK_METHOD(void, - OnProofVerifyDetailsAvailable, - (const ProofVerifyDetails& verify_details), - (override)); + MOCK_METHOD(void, OnProofValid, + (const QuicCryptoClientConfig::CachedState& cached), (override)); + MOCK_METHOD(void, OnProofVerifyDetailsAvailable, + (const ProofVerifyDetails& verify_details), (override)); // TestQuicSpdyClientSession - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (QuicStreamId id), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (PendingStream*), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (PendingStream*), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingBidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingBidirectionalStream, (), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateOutgoingUnidirectionalStream, - (), + MOCK_METHOD(QuicSpdyStream*, CreateOutgoingUnidirectionalStream, (), (override)); MOCK_METHOD(bool, ShouldCreateIncomingStream, (QuicStreamId id), (override)); MOCK_METHOD(bool, ShouldCreateOutgoingBidirectionalStream, (), (override)); @@ -1347,24 +1147,17 @@ class MockPacketWriter : public QuicPacketWriter { MockPacketWriter& operator=(const MockPacketWriter&) = delete; ~MockPacketWriter() override; - MOCK_METHOD(WriteResult, - WritePacket, - (const char*, - size_t buf_len, - const QuicIpAddress& self_address, - const QuicSocketAddress& peer_address, - PerPacketOptions*), + MOCK_METHOD(WriteResult, WritePacket, + (const char*, size_t buf_len, const QuicIpAddress& self_address, + const QuicSocketAddress& peer_address, PerPacketOptions*), (override)); MOCK_METHOD(bool, IsWriteBlocked, (), (const, override)); MOCK_METHOD(void, SetWritable, (), (override)); - MOCK_METHOD(QuicByteCount, - GetMaxPacketSize, - (const QuicSocketAddress& peer_address), - (const, override)); + MOCK_METHOD(QuicByteCount, GetMaxPacketSize, + (const QuicSocketAddress& peer_address), (const, override)); MOCK_METHOD(bool, SupportsReleaseTime, (), (const, override)); MOCK_METHOD(bool, IsBatchMode, (), (const, override)); - MOCK_METHOD(QuicPacketBuffer, - GetNextWriteLocation, + MOCK_METHOD(QuicPacketBuffer, GetNextWriteLocation, (const QuicIpAddress& self_address, const QuicSocketAddress& peer_address), (override)); @@ -1378,32 +1171,19 @@ class MockSendAlgorithm : public SendAlgorithmInterface { MockSendAlgorithm& operator=(const MockSendAlgorithm&) = delete; ~MockSendAlgorithm() override; - MOCK_METHOD(void, - SetFromConfig, - (const QuicConfig& config, Perspective perspective), - (override)); - MOCK_METHOD(void, - ApplyConnectionOptions, - (const QuicTagVector& connection_options), - (override)); - MOCK_METHOD(void, - SetInitialCongestionWindowInPackets, - (QuicPacketCount packets), - (override)); - MOCK_METHOD(void, - OnCongestionEvent, - (bool rtt_updated, - QuicByteCount bytes_in_flight, - QuicTime event_time, - const AckedPacketVector& acked_packets, + MOCK_METHOD(void, SetFromConfig, + (const QuicConfig& config, Perspective perspective), (override)); + MOCK_METHOD(void, ApplyConnectionOptions, + (const QuicTagVector& connection_options), (override)); + MOCK_METHOD(void, SetInitialCongestionWindowInPackets, + (QuicPacketCount packets), (override)); + MOCK_METHOD(void, OnCongestionEvent, + (bool rtt_updated, QuicByteCount bytes_in_flight, + QuicTime event_time, const AckedPacketVector& acked_packets, const LostPacketVector& lost_packets), (override)); - MOCK_METHOD(void, - OnPacketSent, - (QuicTime, - QuicByteCount, - QuicPacketNumber, - QuicByteCount, + MOCK_METHOD(void, OnPacketSent, + (QuicTime, QuicByteCount, QuicPacketNumber, QuicByteCount, HasRetransmittableData), (override)); MOCK_METHOD(void, OnPacketNeutered, (QuicPacketNumber), (override)); @@ -1418,18 +1198,12 @@ class MockSendAlgorithm : public SendAlgorithmInterface { MOCK_METHOD(bool, InRecovery, (), (const, override)); MOCK_METHOD(bool, ShouldSendProbingPacket, (), (const, override)); MOCK_METHOD(QuicByteCount, GetSlowStartThreshold, (), (const, override)); - MOCK_METHOD(CongestionControlType, - GetCongestionControlType, - (), + MOCK_METHOD(CongestionControlType, GetCongestionControlType, (), (const, override)); - MOCK_METHOD(void, - AdjustNetworkParameters, - (const NetworkParams&), + MOCK_METHOD(void, AdjustNetworkParameters, (const NetworkParams&), (override)); MOCK_METHOD(void, OnApplicationLimited, (QuicByteCount), (override)); - MOCK_METHOD(void, - PopulateConnectionStats, - (QuicConnectionStats*), + MOCK_METHOD(void, PopulateConnectionStats, (QuicConnectionStats*), (const, override)); }; @@ -1440,28 +1214,19 @@ class MockLossAlgorithm : public LossDetectionInterface { MockLossAlgorithm& operator=(const MockLossAlgorithm&) = delete; ~MockLossAlgorithm() override; - MOCK_METHOD(void, - SetFromConfig, - (const QuicConfig& config, Perspective perspective), - (override)); + MOCK_METHOD(void, SetFromConfig, + (const QuicConfig& config, Perspective perspective), (override)); - MOCK_METHOD(DetectionStats, - DetectLosses, - (const QuicUnackedPacketMap& unacked_packets, - QuicTime time, + MOCK_METHOD(DetectionStats, DetectLosses, + (const QuicUnackedPacketMap& unacked_packets, QuicTime time, const RttStats& rtt_stats, QuicPacketNumber largest_recently_acked, - const AckedPacketVector& packets_acked, - LostPacketVector*), + const AckedPacketVector& packets_acked, LostPacketVector*), (override)); MOCK_METHOD(QuicTime, GetLossTimeout, (), (const, override)); - MOCK_METHOD(void, - SpuriousLossDetected, - (const QuicUnackedPacketMap&, - const RttStats&, - QuicTime, - QuicPacketNumber, - QuicPacketNumber), + MOCK_METHOD(void, SpuriousLossDetected, + (const QuicUnackedPacketMap&, const RttStats&, QuicTime, + QuicPacketNumber, QuicPacketNumber), (override)); MOCK_METHOD(void, OnConfigNegotiated, (), (override)); @@ -1477,14 +1242,10 @@ class MockAckListener : public QuicAckListenerInterface { MockAckListener(const MockAckListener&) = delete; MockAckListener& operator=(const MockAckListener&) = delete; - MOCK_METHOD(void, - OnPacketAcked, - (int acked_bytes, QuicTime::Delta ack_delay_time), - (override)); + MOCK_METHOD(void, OnPacketAcked, + (int acked_bytes, QuicTime::Delta ack_delay_time), (override)); - MOCK_METHOD(void, - OnPacketRetransmitted, - (int retransmitted_bytes), + MOCK_METHOD(void, OnPacketRetransmitted, (int retransmitted_bytes), (override)); protected: @@ -1509,29 +1270,18 @@ class MockQuicConnectionDebugVisitor : public QuicConnectionDebugVisitor { MockQuicConnectionDebugVisitor(); ~MockQuicConnectionDebugVisitor() override; - MOCK_METHOD(void, - OnPacketSent, - (QuicPacketNumber, - QuicPacketLength, - bool, - TransmissionType, - EncryptionLevel, - const QuicFrames&, - const QuicFrames&, - QuicTime), + MOCK_METHOD(void, OnPacketSent, + (QuicPacketNumber, QuicPacketLength, bool, TransmissionType, + EncryptionLevel, const QuicFrames&, const QuicFrames&, QuicTime), (override)); - MOCK_METHOD(void, - OnCoalescedPacketSent, - (const QuicCoalescedPacket&, size_t), + MOCK_METHOD(void, OnCoalescedPacketSent, (const QuicCoalescedPacket&, size_t), (override)); MOCK_METHOD(void, OnPingSent, (), (override)); - MOCK_METHOD(void, - OnPacketReceived, - (const QuicSocketAddress&, - const QuicSocketAddress&, + MOCK_METHOD(void, OnPacketReceived, + (const QuicSocketAddress&, const QuicSocketAddress&, const QuicEncryptedPacket&), (override)); @@ -1539,83 +1289,57 @@ class MockQuicConnectionDebugVisitor : public QuicConnectionDebugVisitor { MOCK_METHOD(void, OnProtocolVersionMismatch, (ParsedQuicVersion), (override)); - MOCK_METHOD(void, - OnPacketHeader, - (const QuicPacketHeader& header, - QuicTime receive_time, + MOCK_METHOD(void, OnPacketHeader, + (const QuicPacketHeader& header, QuicTime receive_time, EncryptionLevel level), (override)); - MOCK_METHOD(void, - OnSuccessfulVersionNegotiation, - (const ParsedQuicVersion&), + MOCK_METHOD(void, OnSuccessfulVersionNegotiation, (const ParsedQuicVersion&), (override)); MOCK_METHOD(void, OnStreamFrame, (const QuicStreamFrame&), (override)); MOCK_METHOD(void, OnCryptoFrame, (const QuicCryptoFrame&), (override)); - MOCK_METHOD(void, - OnStopWaitingFrame, - (const QuicStopWaitingFrame&), + MOCK_METHOD(void, OnStopWaitingFrame, (const QuicStopWaitingFrame&), (override)); MOCK_METHOD(void, OnRstStreamFrame, (const QuicRstStreamFrame&), (override)); - MOCK_METHOD(void, - OnConnectionCloseFrame, - (const QuicConnectionCloseFrame&), + MOCK_METHOD(void, OnConnectionCloseFrame, (const QuicConnectionCloseFrame&), (override)); MOCK_METHOD(void, OnBlockedFrame, (const QuicBlockedFrame&), (override)); - MOCK_METHOD(void, - OnNewConnectionIdFrame, - (const QuicNewConnectionIdFrame&), + MOCK_METHOD(void, OnNewConnectionIdFrame, (const QuicNewConnectionIdFrame&), (override)); - MOCK_METHOD(void, - OnRetireConnectionIdFrame, - (const QuicRetireConnectionIdFrame&), - (override)); + MOCK_METHOD(void, OnRetireConnectionIdFrame, + (const QuicRetireConnectionIdFrame&), (override)); MOCK_METHOD(void, OnNewTokenFrame, (const QuicNewTokenFrame&), (override)); MOCK_METHOD(void, OnMessageFrame, (const QuicMessageFrame&), (override)); - MOCK_METHOD(void, - OnStopSendingFrame, - (const QuicStopSendingFrame&), + MOCK_METHOD(void, OnStopSendingFrame, (const QuicStopSendingFrame&), (override)); - MOCK_METHOD(void, - OnPathChallengeFrame, - (const QuicPathChallengeFrame&), + MOCK_METHOD(void, OnPathChallengeFrame, (const QuicPathChallengeFrame&), (override)); - MOCK_METHOD(void, - OnPathResponseFrame, - (const QuicPathResponseFrame&), + MOCK_METHOD(void, OnPathResponseFrame, (const QuicPathResponseFrame&), (override)); - MOCK_METHOD(void, - OnPublicResetPacket, - (const QuicPublicResetPacket&), + MOCK_METHOD(void, OnPublicResetPacket, (const QuicPublicResetPacket&), (override)); - MOCK_METHOD(void, - OnVersionNegotiationPacket, - (const QuicVersionNegotiationPacket&), - (override)); + MOCK_METHOD(void, OnVersionNegotiationPacket, + (const QuicVersionNegotiationPacket&), (override)); - MOCK_METHOD(void, - OnTransportParametersSent, - (const TransportParameters&), + MOCK_METHOD(void, OnTransportParametersSent, (const TransportParameters&), (override)); - MOCK_METHOD(void, - OnTransportParametersReceived, - (const TransportParameters&), + MOCK_METHOD(void, OnTransportParametersReceived, (const TransportParameters&), (override)); MOCK_METHOD(void, OnZeroRttRejected, (int), (override)); @@ -1627,14 +1351,11 @@ class MockReceivedPacketManager : public QuicReceivedPacketManager { explicit MockReceivedPacketManager(QuicConnectionStats* stats); ~MockReceivedPacketManager() override; - MOCK_METHOD(void, - RecordPacketReceived, + MOCK_METHOD(void, RecordPacketReceived, (const QuicPacketHeader& header, QuicTime receipt_time), (override)); MOCK_METHOD(bool, IsMissing, (QuicPacketNumber packet_number), (override)); - MOCK_METHOD(bool, - IsAwaitingPacket, - (QuicPacketNumber packet_number), + MOCK_METHOD(bool, IsAwaitingPacket, (QuicPacketNumber packet_number), (const, override)); MOCK_METHOD(bool, HasNewMissingPackets, (), (const, override)); MOCK_METHOD(bool, ack_frame_updated, (), (const, override)); @@ -1650,22 +1371,15 @@ class MockPacketCreatorDelegate : public QuicPacketCreator::DelegateInterface { MOCK_METHOD(QuicPacketBuffer, GetPacketBuffer, (), (override)); MOCK_METHOD(void, OnSerializedPacket, (SerializedPacket), (override)); - MOCK_METHOD(void, - OnUnrecoverableError, - (QuicErrorCode, const std::string&), + MOCK_METHOD(void, OnUnrecoverableError, (QuicErrorCode, const std::string&), (override)); - MOCK_METHOD(bool, - ShouldGeneratePacket, + MOCK_METHOD(bool, ShouldGeneratePacket, (HasRetransmittableData retransmittable, IsHandshake handshake), (override)); - MOCK_METHOD(const QuicFrames, - MaybeBundleAckOpportunistically, - (), - (override)); - MOCK_METHOD(SerializedPacketFate, - GetSerializedPacketFate, - (bool, EncryptionLevel), + MOCK_METHOD(const QuicFrames, MaybeBundleAckOpportunistically, (), (override)); + MOCK_METHOD(SerializedPacketFate, GetSerializedPacketFate, + (bool, EncryptionLevel), (override)); }; class MockSessionNotifier : public SessionNotifierInterface { @@ -1673,19 +1387,13 @@ class MockSessionNotifier : public SessionNotifierInterface { MockSessionNotifier(); ~MockSessionNotifier() override; - MOCK_METHOD(bool, - OnFrameAcked, - (const QuicFrame&, QuicTime::Delta, QuicTime), + MOCK_METHOD(bool, OnFrameAcked, (const QuicFrame&, QuicTime::Delta, QuicTime), (override)); - MOCK_METHOD(void, - OnStreamFrameRetransmitted, - (const QuicStreamFrame&), + MOCK_METHOD(void, OnStreamFrameRetransmitted, (const QuicStreamFrame&), (override)); MOCK_METHOD(void, OnFrameLost, (const QuicFrame&), (override)); - MOCK_METHOD(void, - RetransmitFrames, - (const QuicFrames&, TransmissionType type), - (override)); + MOCK_METHOD(void, RetransmitFrames, + (const QuicFrames&, TransmissionType type), (override)); MOCK_METHOD(bool, IsFrameOutstanding, (const QuicFrame&), (const, override)); MOCK_METHOD(bool, HasUnackedCryptoData, (), (const, override)); MOCK_METHOD(bool, HasUnackedStreamData, (), (const, override)); @@ -1697,8 +1405,7 @@ class MockQuicPathValidationContext : public QuicPathValidationContext { const QuicSocketAddress& peer_address, const QuicSocketAddress& effective_peer_address, QuicPacketWriter* writer) - : QuicPathValidationContext(self_address, - peer_address, + : QuicPathValidationContext(self_address, peer_address, effective_peer_address), writer_(writer) {} QuicPacketWriter* WriterToUse() override { return writer_; } @@ -1710,15 +1417,11 @@ class MockQuicPathValidationContext : public QuicPathValidationContext { class MockQuicPathValidationResultDelegate : public QuicPathValidator::ResultDelegate { public: - MOCK_METHOD(void, - OnPathValidationSuccess, - (std::unique_ptr), - (override)); + MOCK_METHOD(void, OnPathValidationSuccess, + (std::unique_ptr), (override)); - MOCK_METHOD(void, - OnPathValidationFailure, - (std::unique_ptr), - (override)); + MOCK_METHOD(void, OnPathValidationFailure, + (std::unique_ptr), (override)); }; class QuicCryptoClientStreamPeer { @@ -1745,11 +1448,9 @@ class QuicCryptoClientStreamPeer { // client_session: Pointer reference for the newly created client // session. The new object will be owned by the caller. void CreateClientSessionForTest( - QuicServerId server_id, - QuicTime::Delta connection_start_time, + QuicServerId server_id, QuicTime::Delta connection_start_time, const ParsedQuicVersionVector& supported_versions, - MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, + MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, QuicCryptoClientConfig* crypto_client_config, PacketSavingConnection** client_connection, TestQuicSpdyClientSession** client_session); @@ -1770,11 +1471,9 @@ void CreateClientSessionForTest( // server_session: Pointer reference for the newly created server // session. The new object will be owned by the caller. void CreateServerSessionForTest( - QuicServerId server_id, - QuicTime::Delta connection_start_time, + QuicServerId server_id, QuicTime::Delta connection_start_time, ParsedQuicVersionVector supported_versions, - MockQuicConnectionHelper* helper, - MockAlarmFactory* alarm_factory, + MockQuicConnectionHelper* helper, MockAlarmFactory* alarm_factory, QuicCryptoServerConfig* crypto_server_config, QuicCompressedCertsCache* compressed_certs_cache, PacketSavingConnection** server_connection, @@ -1826,30 +1525,18 @@ inline void MakeIOVector(absl::string_view str, struct iovec* iov) { // HTTP stream numbering scheme (i.e. whether one or two QUIC streams are used // per HTTP transaction). QuicStreamId GetNthClientInitiatedBidirectionalStreamId( - QuicTransportVersion version, - int n); + QuicTransportVersion version, int n); QuicStreamId GetNthServerInitiatedBidirectionalStreamId( - QuicTransportVersion version, - int n); + QuicTransportVersion version, int n); QuicStreamId GetNthServerInitiatedUnidirectionalStreamId( - QuicTransportVersion version, - int n); + QuicTransportVersion version, int n); QuicStreamId GetNthClientInitiatedUnidirectionalStreamId( - QuicTransportVersion version, - int n); + QuicTransportVersion version, int n); -StreamType DetermineStreamType(QuicStreamId id, - ParsedQuicVersion version, - Perspective perspective, - bool is_incoming, +StreamType DetermineStreamType(QuicStreamId id, ParsedQuicVersion version, + Perspective perspective, bool is_incoming, StreamType default_type); -// Utility function that stores message_data in |storage| and returns a -// QuicMemSliceSpan. -QuicMemSliceSpan MakeSpan(QuicBufferAllocator* allocator, - absl::string_view message_data, - QuicMemSliceStorage* storage); - // Creates a MemSlice using a singleton trivial buffer allocator. Performs a // copy. QuicMemSlice MemSliceFromString(absl::string_view data); @@ -1863,15 +1550,12 @@ MATCHER_P(ReceivedPacketInfoConnectionIdEquals, destination_connection_id, "") { return arg.destination_connection_id == destination_connection_id; } -MATCHER_P2(InRange, min, max, "") { - return arg >= min && arg <= max; -} +MATCHER_P2(InRange, min, max, "") { return arg >= min && arg <= max; } // A GMock matcher that prints expected and actual QuicErrorCode strings // upon failure. Example usage: // EXPECT_THAT(stream_->connection_error(), IsError(QUIC_INTERNAL_ERROR)); -MATCHER_P(IsError, - expected, +MATCHER_P(IsError, expected, absl::StrCat(negation ? "isn't equal to " : "is equal to ", QuicErrorCodeToString(expected))) { *result_listener << QuicErrorCodeToString(static_cast(arg)); @@ -1890,8 +1574,7 @@ MATCHER(IsQuicNoError, // A GMock matcher that prints expected and actual QuicRstStreamErrorCode // strings upon failure. Example usage: // EXPECT_THAT(stream_->stream_error(), IsStreamError(QUIC_INTERNAL_ERROR)); -MATCHER_P(IsStreamError, - expected, +MATCHER_P(IsStreamError, expected, absl::StrCat(negation ? "isn't equal to " : "is equal to ", QuicRstStreamErrorCodeToString(expected))) { *result_listener << QuicRstStreamErrorCodeToString(arg); @@ -1929,12 +1612,9 @@ class TaggingEncrypter : public QuicEncrypter { return true; } - bool EncryptPacket(uint64_t packet_number, - absl::string_view associated_data, - absl::string_view plaintext, - char* output, - size_t* output_length, - size_t max_output_length) override; + bool EncryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view plaintext, char* output, + size_t* output_length, size_t max_output_length) override; std::string GenerateHeaderProtectionMask( absl::string_view /*sample*/) override { @@ -1999,12 +1679,9 @@ class TaggingDecrypter : public QuicDecrypter { return true; } - bool DecryptPacket(uint64_t packet_number, - absl::string_view associated_data, - absl::string_view ciphertext, - char* output, - size_t* output_length, - size_t max_output_length) override; + bool DecryptPacket(uint64_t packet_number, absl::string_view associated_data, + absl::string_view ciphertext, char* output, + size_t* output_length, size_t max_output_length) override; std::string GenerateHeaderProtectionMask( QuicDataReader* /*sample_reader*/) override { @@ -2061,8 +1738,7 @@ class TestPacketWriter : public QuicPacketWriter { }; public: - TestPacketWriter(ParsedQuicVersion version, - MockClock* clock, + TestPacketWriter(ParsedQuicVersion version, MockClock* clock, Perspective perspective); TestPacketWriter(const TestPacketWriter&) = delete; @@ -2071,8 +1747,7 @@ class TestPacketWriter : public QuicPacketWriter { ~TestPacketWriter() override; // QuicPacketWriter interface - WriteResult WritePacket(const char* buffer, - size_t buf_len, + WriteResult WritePacket(const char* buffer, size_t buf_len, const QuicIpAddress& self_address, const QuicSocketAddress& peer_address, PerPacketOptions* options) override; @@ -2298,8 +1973,7 @@ class TestPacketWriter : public QuicPacketWriter { // contents of the destination connection ID passed in to // WriteClientVersionNegotiationProbePacket. bool ParseClientVersionNegotiationProbePacket( - const char* packet_bytes, - size_t packet_length, + const char* packet_bytes, size_t packet_length, char* destination_connection_id_bytes, uint8_t* destination_connection_id_length_out); @@ -2312,37 +1986,55 @@ bool ParseClientVersionNegotiationProbePacket( // the source connection ID, and must point to |source_connection_id_length| // bytes of memory. bool WriteServerVersionNegotiationProbeResponse( - char* packet_bytes, - size_t* packet_length_out, + char* packet_bytes, size_t* packet_length_out, const char* source_connection_id_bytes, uint8_t source_connection_id_length); // Implementation of Http3DatagramVisitor which saves all received datagrams. -class SavingHttp3DatagramVisitor - : public QuicSpdySession::Http3DatagramVisitor { +class SavingHttp3DatagramVisitor : public QuicSpdyStream::Http3DatagramVisitor { public: struct SavedHttp3Datagram { - QuicDatagramFlowId flow_id; + QuicStreamId stream_id; + absl::optional context_id; std::string payload; bool operator==(const SavedHttp3Datagram& o) const { - return flow_id == o.flow_id && payload == o.payload; + return stream_id == o.stream_id && context_id == o.context_id && + payload == o.payload; } }; const std::vector& received_h3_datagrams() const { return received_h3_datagrams_; } - // Override from QuicSpdySession::Http3DatagramVisitor. - void OnHttp3Datagram(QuicDatagramFlowId flow_id, + // Override from QuicSpdyStream::Http3DatagramVisitor. + void OnHttp3Datagram(QuicStreamId stream_id, + absl::optional context_id, absl::string_view payload) override { received_h3_datagrams_.push_back( - SavedHttp3Datagram{flow_id, std::string(payload)}); + SavedHttp3Datagram{stream_id, context_id, std::string(payload)}); } private: std::vector received_h3_datagrams_; }; +class MockHttp3DatagramRegistrationVisitor + : public QuicSpdyStream::Http3DatagramRegistrationVisitor { + public: + MOCK_METHOD(void, OnContextReceived, + (QuicStreamId stream_id, + absl::optional context_id, + DatagramFormatType format_type, + absl::string_view format_additional_data), + (override)); + + MOCK_METHOD(void, OnContextClosed, + (QuicStreamId stream_id, + absl::optional context_id, + ContextCloseCode close_code, absl::string_view close_details), + (override)); +}; + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.cc b/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.cc index dd4677f8..946bf133 100644 --- a/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.cc +++ b/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.cc @@ -36,5 +36,11 @@ bool QuicTimeWaitListManagerPeer::SendOrQueuePacket( return manager->SendOrQueuePacket(std::move(packet), packet_context); } +// static +size_t QuicTimeWaitListManagerPeer::PendingPacketsQueueSize( + QuicTimeWaitListManager* manager) { + return manager->pending_packets_queue_.size(); +} + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.h b/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.h index 2a903c34..f104e0f7 100644 --- a/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.h +++ b/gquiche/quic/test_tools/quic_time_wait_list_manager_peer.h @@ -26,6 +26,8 @@ class QuicTimeWaitListManagerPeer { QuicTimeWaitListManager* manager, std::unique_ptr packet, const QuicPerPacketContext* packet_context); + + static size_t PendingPacketsQueueSize(QuicTimeWaitListManager* manager); }; } // namespace test diff --git a/gquiche/quic/test_tools/quic_transport_test_tools.h b/gquiche/quic/test_tools/quic_transport_test_tools.h index beee4a2b..54a32a90 100644 --- a/gquiche/quic/test_tools/quic_transport_test_tools.h +++ b/gquiche/quic/test_tools/quic_transport_test_tools.h @@ -14,7 +14,9 @@ namespace test { class MockClientVisitor : public WebTransportVisitor { public: - MOCK_METHOD(void, OnSessionReady, (), (override)); + MOCK_METHOD(void, OnSessionReady, (const spdy::SpdyHeaderBlock&), (override)); + MOCK_METHOD(void, OnSessionClosed, + (WebTransportSessionError, const std::string&), (override)); MOCK_METHOD(void, OnIncomingBidirectionalStreamAvailable, (), (override)); MOCK_METHOD(void, OnIncomingUnidirectionalStreamAvailable, (), (override)); MOCK_METHOD(void, OnDatagramReceived, (absl::string_view), (override)); @@ -32,6 +34,12 @@ class MockStreamVisitor : public WebTransportStreamVisitor { public: MOCK_METHOD(void, OnCanRead, (), (override)); MOCK_METHOD(void, OnCanWrite, (), (override)); + + MOCK_METHOD(void, OnResetStreamReceived, (WebTransportStreamError error), + (override)); + MOCK_METHOD(void, OnStopSendingReceived, (WebTransportStreamError error), + (override)); + MOCK_METHOD(void, OnWriteSideInDataRecvdState, (), (override)); }; } // namespace test diff --git a/gquiche/quic/test_tools/server_thread.cc b/gquiche/quic/test_tools/server_thread.cc index 7ba245f8..6d046b0e 100644 --- a/gquiche/quic/test_tools/server_thread.cc +++ b/gquiche/quic/test_tools/server_thread.cc @@ -126,7 +126,7 @@ void ServerThread::MaybeNotifyOfHandshakeConfirmation() { } void ServerThread::ExecuteScheduledActions() { - QuicCircularDeque> actions; + quiche::QuicheCircularDeque> actions; { QuicWriterMutexLock lock(&scheduled_actions_lock_); actions.swap(scheduled_actions_); diff --git a/gquiche/quic/test_tools/server_thread.h b/gquiche/quic/test_tools/server_thread.h index 997a5055..22d4442c 100644 --- a/gquiche/quic/test_tools/server_thread.h +++ b/gquiche/quic/test_tools/server_thread.h @@ -86,7 +86,7 @@ class ServerThread : public QuicThread { bool initialized_; QuicMutex scheduled_actions_lock_; - QuicCircularDeque> scheduled_actions_ + quiche::QuicheCircularDeque> scheduled_actions_ QUIC_GUARDED_BY(scheduled_actions_lock_); }; diff --git a/gquiche/quic/test_tools/simple_data_producer.cc b/gquiche/quic/test_tools/simple_data_producer.cc index 91fdf23f..d0f3a590 100644 --- a/gquiche/quic/test_tools/simple_data_producer.cc +++ b/gquiche/quic/test_tools/simple_data_producer.cc @@ -10,7 +10,6 @@ #include "gquiche/quic/core/quic_data_writer.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/quic/platform/api/quic_map_util.h" namespace quic { @@ -28,7 +27,7 @@ void SimpleDataProducer::SaveStreamData(QuicStreamId id, if (data_length == 0) { return; } - if (!QuicContainsKey(send_buffer_map_, id)) { + if (!send_buffer_map_.contains(id)) { send_buffer_map_[id] = std::make_unique(&allocator_); } send_buffer_map_[id]->SaveStreamData(iov, iov_count, iov_offset, data_length); diff --git a/gquiche/quic/test_tools/simple_data_producer.h b/gquiche/quic/test_tools/simple_data_producer.h index 498dec4c..8a805fa3 100644 --- a/gquiche/quic/test_tools/simple_data_producer.h +++ b/gquiche/quic/test_tools/simple_data_producer.h @@ -49,23 +49,13 @@ class SimpleDataProducer : public QuicStreamFrameDataProducer { QuicByteCount data_length, QuicDataWriter* writer) override; - // TODO(wub): Allow QuicDefaultHasher to accept a pair. Then remove this. - class PairHash { - public: - template - size_t operator()(const std::pair& pair) const { - return std::hash()(pair.first) ^ std::hash()(pair.second); - } - }; - private: using SendBufferMap = absl::flat_hash_map>; using CryptoBufferMap = absl::flat_hash_map, - absl::string_view, - PairHash>; + absl::string_view>; SimpleBufferAllocator allocator_; diff --git a/gquiche/quic/test_tools/simple_session_cache.cc b/gquiche/quic/test_tools/simple_session_cache.cc index a4b5bb55..c53f959e 100644 --- a/gquiche/quic/test_tools/simple_session_cache.cc +++ b/gquiche/quic/test_tools/simple_session_cache.cc @@ -28,7 +28,7 @@ void SimpleSessionCache::Insert(const QuicServerId& server_id, } std::unique_ptr SimpleSessionCache::Lookup( - const QuicServerId& server_id, + const QuicServerId& server_id, QuicWallTime /*now*/, const SSL_CTX* /*ctx*/) { auto it = cache_entries_.find(server_id); if (it == cache_entries_.end()) { @@ -48,6 +48,7 @@ std::unique_ptr SimpleSessionCache::Lookup( } state->transport_params = std::make_unique(*it->second.params); + state->token = it->second.token; return state; } @@ -56,5 +57,20 @@ void SimpleSessionCache::ClearEarlyData(const QuicServerId& /*server_id*/) { // do anything here. } +void SimpleSessionCache::OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) { + auto it = cache_entries_.find(server_id); + if (it == cache_entries_.end()) { + return; + } + it->second.token = std::string(token); +} + +void SimpleSessionCache::RemoveExpiredEntries(QuicWallTime /*now*/) { + // The simple session cache does not support removing expired entries. +} + +void SimpleSessionCache::Clear() { cache_entries_.clear(); } + } // namespace test } // namespace quic diff --git a/gquiche/quic/test_tools/simple_session_cache.h b/gquiche/quic/test_tools/simple_session_cache.h index 47e7f5e1..0e36354f 100644 --- a/gquiche/quic/test_tools/simple_session_cache.h +++ b/gquiche/quic/test_tools/simple_session_cache.h @@ -17,6 +17,7 @@ namespace test { // the total number of entries in the cache. When Lookup is called, if a cache // entry exists for the provided QuicServerId, the entry will be removed from // the cached when it is returned. +// TODO(fayang): Remove SimpleSessionCache by using QuicClientSessionCache. class SimpleSessionCache : public SessionCache { public: SimpleSessionCache() = default; @@ -27,14 +28,20 @@ class SimpleSessionCache : public SessionCache { const TransportParameters& params, const ApplicationState* application_state) override; std::unique_ptr Lookup(const QuicServerId& server_id, + QuicWallTime now, const SSL_CTX* ctx) override; void ClearEarlyData(const QuicServerId& server_id) override; + void OnNewTokenReceived(const QuicServerId& server_id, + absl::string_view token) override; + void RemoveExpiredEntries(QuicWallTime now) override; + void Clear() override; private: struct Entry { bssl::UniquePtr session; std::unique_ptr params; std::unique_ptr application_state; + std::string token; }; std::map cache_entries_; }; diff --git a/gquiche/quic/test_tools/simple_session_notifier.cc b/gquiche/quic/test_tools/simple_session_notifier.cc index e80aa1e1..653649d7 100644 --- a/gquiche/quic/test_tools/simple_session_notifier.cc +++ b/gquiche/quic/test_tools/simple_session_notifier.cc @@ -6,7 +6,6 @@ #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/test_tools/quic_test_utils.h" namespace quic { @@ -40,7 +39,7 @@ QuicConsumedData SimpleSessionNotifier::WriteOrBufferData( QuicStreamId id, QuicByteCount data_length, StreamSendingState state) { - if (!QuicContainsKey(stream_map_, id)) { + if (!stream_map_.contains(id)) { stream_map_[id] = StreamState(); } StreamState& stream_state = stream_map_.find(id)->second; @@ -112,6 +111,21 @@ void SimpleSessionNotifier::WriteOrBufferRstStream( WriteBufferedControlFrames(); } +void SimpleSessionNotifier::WriteOrBufferWindowUpate( + QuicStreamId id, QuicStreamOffset byte_offset) { + QUIC_DVLOG(1) << "Writing WINDOW_UPDATE"; + const bool had_buffered_data = + HasBufferedStreamData() || HasBufferedControlFrames(); + QuicControlFrameId control_frame_id = ++last_control_frame_id_; + control_frames_.emplace_back(( + QuicFrame(new QuicWindowUpdateFrame(control_frame_id, id, byte_offset)))); + if (had_buffered_data) { + QUIC_DLOG(WARNING) << "Connection is write blocked"; + return; + } + WriteBufferedControlFrames(); +} + void SimpleSessionNotifier::WriteOrBufferPing() { QUIC_DVLOG(1) << "Writing PING_FRAME"; const bool had_buffered_data = @@ -163,23 +177,20 @@ void SimpleSessionNotifier::NeuterUnencryptedData() { } void SimpleSessionNotifier::OnCanWrite() { - if (connection_->donot_write_mid_packet_processing()) { - if (connection_->framer().is_processing_packet()) { - // Do not write data in the middle of packet processing because rest - // frames in the packet may change the data to write. For example, lost - // data could be acknowledged. Also, connection is going to emit - // OnCanWrite signal post packet processing. - QUIC_BUG(simple_notifier_write_mid_packet_processing) - << "Try to write mid packet processing."; - return; - } + if (connection_->framer().is_processing_packet()) { + // Do not write data in the middle of packet processing because rest + // frames in the packet may change the data to write. For example, lost + // data could be acknowledged. Also, connection is going to emit + // OnCanWrite signal post packet processing. + QUIC_BUG(simple_notifier_write_mid_packet_processing) + << "Try to write mid packet processing."; + return; } if (!RetransmitLostCryptoData() || !RetransmitLostControlFrames() || !RetransmitLostStreamData()) { return; } - // Write buffered control frames. - if (!WriteBufferedControlFrames()) { + if (!WriteBufferedCryptoData() || !WriteBufferedControlFrames()) { return; } // Write new data. @@ -263,7 +274,7 @@ bool SimpleSessionNotifier::OnFrameAcked(const QuicFrame& frame, if (frame.type != STREAM_FRAME) { return OnControlFrameAcked(frame); } - if (!QuicContainsKey(stream_map_, frame.stream_frame.stream_id)) { + if (!stream_map_.contains(frame.stream_frame.stream_id)) { return false; } auto* state = &stream_map_.find(frame.stream_frame.stream_id)->second; @@ -304,7 +315,7 @@ void SimpleSessionNotifier::OnFrameLost(const QuicFrame& frame) { OnControlFrameLost(frame); return; } - if (!QuicContainsKey(stream_map_, frame.stream_frame.stream_id)) { + if (!stream_map_.contains(frame.stream_frame.stream_id)) { return; } auto* state = &stream_map_.find(frame.stream_frame.stream_id)->second; @@ -359,7 +370,7 @@ void SimpleSessionNotifier::RetransmitFrames(const QuicFrames& frames, } continue; } - if (!QuicContainsKey(stream_map_, frame.stream_frame.stream_id)) { + if (!stream_map_.contains(frame.stream_frame.stream_id)) { continue; } const auto& state = stream_map_.find(frame.stream_frame.stream_id)->second; @@ -437,7 +448,7 @@ bool SimpleSessionNotifier::IsFrameOutstanding(const QuicFrame& frame) const { if (frame.type != STREAM_FRAME) { return IsControlFrameOutstanding(frame); } - if (!QuicContainsKey(stream_map_, frame.stream_frame.stream_id)) { + if (!stream_map_.contains(frame.stream_frame.stream_id)) { return false; } const auto& state = stream_map_.find(frame.stream_frame.stream_id)->second; @@ -463,8 +474,8 @@ bool SimpleSessionNotifier::HasUnackedCryptoData() const { } return false; } - if (!QuicContainsKey(stream_map_, QuicUtils::GetCryptoStreamId( - connection_->transport_version()))) { + if (!stream_map_.contains( + QuicUtils::GetCryptoStreamId(connection_->transport_version()))) { return false; } const auto& state = @@ -521,7 +532,7 @@ void SimpleSessionNotifier::OnControlFrameLost(const QuicFrame& frame) { kInvalidControlFrameId) { return; } - if (!QuicContainsKey(lost_control_frames_, id)) { + if (!lost_control_frames_.contains(id)) { lost_control_frames_[id] = true; } } @@ -584,8 +595,8 @@ bool SimpleSessionNotifier::RetransmitLostCryptoData() { } return true; } - if (!QuicContainsKey(stream_map_, QuicUtils::GetCryptoStreamId( - connection_->transport_version()))) { + if (!stream_map_.contains( + QuicUtils::GetCryptoStreamId(connection_->transport_version()))) { return true; } auto& state = @@ -669,6 +680,26 @@ bool SimpleSessionNotifier::RetransmitLostStreamData() { return !HasLostStreamData(); } +bool SimpleSessionNotifier::WriteBufferedCryptoData() { + for (size_t i = 0; i < NUM_ENCRYPTION_LEVELS; ++i) { + const StreamState& state = crypto_state_[i]; + QuicIntervalSet buffered_crypto_data(0, + state.bytes_total); + buffered_crypto_data.Difference(crypto_bytes_transferred_[i]); + for (const auto& interval : buffered_crypto_data) { + size_t bytes_written = connection_->SendCryptoData( + static_cast(i), interval.Length(), interval.min()); + crypto_state_[i].bytes_sent += bytes_written; + crypto_bytes_transferred_[i].Add(interval.min(), + interval.min() + bytes_written); + if (bytes_written < interval.Length()) { + return false; + } + } + } + return true; +} + bool SimpleSessionNotifier::WriteBufferedControlFrames() { while (HasBufferedControlFrames()) { QuicFrame frame_to_send = @@ -701,7 +732,7 @@ bool SimpleSessionNotifier::HasBufferedStreamData() const { } bool SimpleSessionNotifier::StreamIsWaitingForAcks(QuicStreamId id) const { - if (!QuicContainsKey(stream_map_, id)) { + if (!stream_map_.contains(id)) { return false; } const StreamState& state = stream_map_.find(id)->second; @@ -710,7 +741,7 @@ bool SimpleSessionNotifier::StreamIsWaitingForAcks(QuicStreamId id) const { } bool SimpleSessionNotifier::StreamHasBufferedData(QuicStreamId id) const { - if (!QuicContainsKey(stream_map_, id)) { + if (!stream_map_.contains(id)) { return false; } const StreamState& state = stream_map_.find(id)->second; diff --git a/gquiche/quic/test_tools/simple_session_notifier.h b/gquiche/quic/test_tools/simple_session_notifier.h index 197daa22..f18decb2 100644 --- a/gquiche/quic/test_tools/simple_session_notifier.h +++ b/gquiche/quic/test_tools/simple_session_notifier.h @@ -6,10 +6,11 @@ #define QUICHE_QUIC_TEST_TOOLS_SIMPLE_SESSION_NOTIFIER_H_ #include "absl/container/flat_hash_map.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/core/quic_interval_set.h" #include "gquiche/quic/core/session_notifier_interface.h" #include "gquiche/quic/platform/api/quic_test.h" +#include "gquiche/common/quiche_circular_deque.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -33,6 +34,10 @@ class SimpleSessionNotifier : public SessionNotifierInterface { void WriteOrBufferRstStream(QuicStreamId id, QuicRstStreamErrorCode error, QuicStreamOffset bytes_written); + + // Tries to write WINDOW_UPDATE. + void WriteOrBufferWindowUpate(QuicStreamId id, QuicStreamOffset byte_offset); + // Tries to write PING. void WriteOrBufferPing(); @@ -126,15 +131,17 @@ class SimpleSessionNotifier : public SessionNotifierInterface { bool WriteBufferedControlFrames(); + bool WriteBufferedCryptoData(); + bool IsControlFrameOutstanding(const QuicFrame& frame) const; bool HasBufferedControlFrames() const; bool StreamHasBufferedData(QuicStreamId id) const; - QuicCircularDeque control_frames_; + quiche::QuicheCircularDeque control_frames_; - QuicLinkedHashMap lost_control_frames_; + quiche::QuicheLinkedHashMap lost_control_frames_; // Id of latest saved control frame. 0 if no control frame has been saved. QuicControlFrameId last_control_frame_id_; diff --git a/gquiche/quic/test_tools/simulator/link.h b/gquiche/quic/test_tools/simulator/link.h index 30a8d3fe..ec27a595 100644 --- a/gquiche/quic/test_tools/simulator/link.h +++ b/gquiche/quic/test_tools/simulator/link.h @@ -9,9 +9,9 @@ #include "gquiche/quic/core/crypto/quic_random.h" #include "gquiche/quic/core/quic_bandwidth.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/test_tools/simulator/actor.h" #include "gquiche/quic/test_tools/simulator/port.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { namespace simulator { @@ -61,7 +61,7 @@ class OneWayLink : public Actor, public ConstrainedPortInterface { void ScheduleNextPacketDeparture(); UnconstrainedPortInterface* sink_; - QuicCircularDeque packets_in_transit_; + quiche::QuicheCircularDeque packets_in_transit_; QuicBandwidth bandwidth_; const QuicTime::Delta propagation_delay_; diff --git a/gquiche/quic/test_tools/simulator/queue.cc b/gquiche/quic/test_tools/simulator/queue.cc index 49841cb1..cddd55d7 100644 --- a/gquiche/quic/test_tools/simulator/queue.cc +++ b/gquiche/quic/test_tools/simulator/queue.cc @@ -26,7 +26,7 @@ Queue::Queue(Simulator* simulator, std::string name, QuicByteCount capacity) new AggregationAlarmDelegate(this))); } -Queue::~Queue() {} +Queue::~Queue() { aggregation_timeout_alarm_->PermanentCancel(); } void Queue::set_tx_port(ConstrainedPortInterface* port) { tx_port_ = port; diff --git a/gquiche/quic/test_tools/simulator/queue.h b/gquiche/quic/test_tools/simulator/queue.h index 1f374832..d00bb3e6 100644 --- a/gquiche/quic/test_tools/simulator/queue.h +++ b/gquiche/quic/test_tools/simulator/queue.h @@ -6,8 +6,8 @@ #define QUICHE_QUIC_TEST_TOOLS_SIMULATOR_QUEUE_H_ #include "gquiche/quic/core/quic_alarm.h" -#include "gquiche/quic/core/quic_circular_deque.h" #include "gquiche/quic/test_tools/simulator/link.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { namespace simulator { @@ -73,7 +73,7 @@ class Queue : public Actor, public UnconstrainedPortInterface { }; // Alarm handler for aggregation timeout. - class AggregationAlarmDelegate : public QuicAlarm::Delegate { + class AggregationAlarmDelegate : public QuicAlarm::DelegateWithoutContext { public: explicit AggregationAlarmDelegate(Queue* queue); @@ -110,7 +110,7 @@ class Queue : public Actor, public UnconstrainedPortInterface { std::unique_ptr aggregation_timeout_alarm_; ConstrainedPortInterface* tx_port_; - QuicCircularDeque queue_; + quiche::QuicheCircularDeque queue_; ListenerInterface* listener_; }; diff --git a/gquiche/quic/test_tools/simulator/quic_endpoint.h b/gquiche/quic/test_tools/simulator/quic_endpoint.h index a3e32702..181d33bc 100644 --- a/gquiche/quic/test_tools/simulator/quic_endpoint.h +++ b/gquiche/quic/test_tools/simulator/quic_endpoint.h @@ -110,10 +110,8 @@ class QuicEndpoint : public QuicEndpointBase, return nullptr; } void BeforeConnectionCloseSent() override {} - bool ValidateToken(absl::string_view /*token*/) const override { - return true; - } - void MaybeSendAddressToken() override {} + bool ValidateToken(absl::string_view /*token*/) override { return true; } + bool MaybeSendAddressToken() override { return false; } bool IsKnownServerAddress( const QuicSocketAddress& /*address*/) const override { return false; diff --git a/gquiche/quic/test_tools/simulator/simulator.h b/gquiche/quic/test_tools/simulator/simulator.h index d9b6318f..1d38d84c 100644 --- a/gquiche/quic/test_tools/simulator/simulator.h +++ b/gquiche/quic/test_tools/simulator/simulator.h @@ -88,7 +88,7 @@ class Simulator : public QuicConnectionHelperInterface { }; // The delegate used for RunFor(). - class RunForDelegate : public QuicAlarm::Delegate { + class RunForDelegate : public QuicAlarm::DelegateWithoutContext { public: explicit RunForDelegate(bool* run_for_should_stop); void OnAlarm() override; diff --git a/gquiche/quic/test_tools/simulator/simulator_test.cc b/gquiche/quic/test_tools/simulator/simulator_test.cc index 3f3359ed..7e5525f7 100644 --- a/gquiche/quic/test_tools/simulator/simulator_test.cc +++ b/gquiche/quic/test_tools/simulator/simulator_test.cc @@ -475,7 +475,7 @@ class AlarmToggler : public Actor { }; // Counts the number of times an alarm has fired. -class CounterDelegate : public QuicAlarm::Delegate { +class CounterDelegate : public QuicAlarm::DelegateWithoutContext { public: explicit CounterDelegate(size_t* counter) : counter_(counter) {} diff --git a/gquiche/quic/test_tools/simulator/switch.h b/gquiche/quic/test_tools/simulator/switch.h index 84b94829..c19d7d92 100644 --- a/gquiche/quic/test_tools/simulator/switch.h +++ b/gquiche/quic/test_tools/simulator/switch.h @@ -78,7 +78,7 @@ class Switch { void DispatchPacket(SwitchPortNumber port_number, std::unique_ptr packet); - // This can not be a QuicCircularDeque since pointers into this are + // This cannot be a quiche::QuicheCircularDeque since pointers into this are // assumed to be stable. std::deque ports_; absl::flat_hash_map switching_table_; diff --git a/gquiche/quic/test_tools/test_ticket_crypter.cc b/gquiche/quic/test_tools/test_ticket_crypter.cc index 8bb8679e..d48158d6 100644 --- a/gquiche/quic/test_tools/test_ticket_crypter.cc +++ b/gquiche/quic/test_tools/test_ticket_crypter.cc @@ -37,7 +37,8 @@ size_t TestTicketCrypter::MaxOverhead() { return ticket_prefix_.size(); } -std::vector TestTicketCrypter::Encrypt(absl::string_view in) { +std::vector TestTicketCrypter::Encrypt( + absl::string_view in, absl::string_view /* encryption_key */) { size_t prefix_len = ticket_prefix_.size(); std::vector out(prefix_len + in.size()); memcpy(out.data(), ticket_prefix_.data(), prefix_len); diff --git a/gquiche/quic/test_tools/test_ticket_crypter.h b/gquiche/quic/test_tools/test_ticket_crypter.h index eff7d073..0a6982ee 100644 --- a/gquiche/quic/test_tools/test_ticket_crypter.h +++ b/gquiche/quic/test_tools/test_ticket_crypter.h @@ -19,7 +19,8 @@ class TestTicketCrypter : public ProofSource::TicketCrypter { // TicketCrypter interface size_t MaxOverhead() override; - std::vector Encrypt(absl::string_view in) override; + std::vector Encrypt(absl::string_view in, + absl::string_view encryption_key) override; void Decrypt(absl::string_view in, std::unique_ptr callback) override; diff --git a/gquiche/quic/test_tools/web_transport_resets_backend.cc b/gquiche/quic/test_tools/web_transport_resets_backend.cc new file mode 100644 index 00000000..3375936c --- /dev/null +++ b/gquiche/quic/test_tools/web_transport_resets_backend.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "gquiche/quic/test_tools/web_transport_resets_backend.h" + +#include + +#include "gquiche/quic/core/web_transport_interface.h" +#include "gquiche/quic/tools/web_transport_test_visitors.h" +#include "gquiche/common/quiche_circular_deque.h" + +namespace quic { +namespace test { + +namespace { + +class ResetsVisitor; + +class BidirectionalEchoVisitorWithLogging + : public WebTransportBidirectionalEchoVisitor { + public: + BidirectionalEchoVisitorWithLogging(WebTransportStream* stream, + ResetsVisitor* session_visitor) + : WebTransportBidirectionalEchoVisitor(stream), + session_visitor_(session_visitor) {} + + void OnResetStreamReceived(WebTransportStreamError error) override; + void OnStopSendingReceived(WebTransportStreamError error) override; + + private: + ResetsVisitor* session_visitor_; // Not owned. +}; + +class ResetsVisitor : public WebTransportVisitor { + public: + ResetsVisitor(WebTransportSession* session) : session_(session) {} + + void OnSessionReady(const spdy::SpdyHeaderBlock& /*headers*/) override {} + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + + void OnIncomingBidirectionalStreamAvailable() override { + while (true) { + WebTransportStream* stream = + session_->AcceptIncomingBidirectionalStream(); + if (stream == nullptr) { + return; + } + stream->SetVisitor( + std::make_unique(stream, this)); + stream->visitor()->OnCanRead(); + } + } + void OnIncomingUnidirectionalStreamAvailable() override {} + + void OnDatagramReceived(absl::string_view /*datagram*/) override {} + + void OnCanCreateNewOutgoingBidirectionalStream() override {} + void OnCanCreateNewOutgoingUnidirectionalStream() override { + MaybeSendLogsBack(); + } + + void Log(std::string line) { + log_.push_back(std::move(line)); + MaybeSendLogsBack(); + } + + private: + void MaybeSendLogsBack() { + while (!log_.empty() && + session_->CanOpenNextOutgoingUnidirectionalStream()) { + WebTransportStream* stream = session_->OpenOutgoingUnidirectionalStream(); + stream->SetVisitor( + std::make_unique( + stream, log_.front())); + log_.pop_front(); + stream->visitor()->OnCanWrite(); + } + } + + WebTransportSession* session_; // Not owned. + quiche::QuicheCircularDeque log_; +}; + +void BidirectionalEchoVisitorWithLogging::OnResetStreamReceived( + WebTransportStreamError error) { + session_visitor_->Log(absl::StrCat("Received reset for stream ", + stream()->GetStreamId(), + " with error code ", error)); + WebTransportBidirectionalEchoVisitor::OnResetStreamReceived(error); +} +void BidirectionalEchoVisitorWithLogging::OnStopSendingReceived( + WebTransportStreamError error) { + session_visitor_->Log(absl::StrCat("Received stop sending for stream ", + stream()->GetStreamId(), + " with error code ", error)); + WebTransportBidirectionalEchoVisitor::OnStopSendingReceived(error); +} + +} // namespace + +QuicSimpleServerBackend::WebTransportResponse WebTransportResetsBackend( + const spdy::Http2HeaderBlock& /*request_headers*/, + WebTransportSession* session) { + QuicSimpleServerBackend::WebTransportResponse response; + response.response_headers[":status"] = "200"; + response.visitor = std::make_unique(session); + return response; +} + +} // namespace test +} // namespace quic diff --git a/gquiche/quic/test_tools/web_transport_resets_backend.h b/gquiche/quic/test_tools/web_transport_resets_backend.h new file mode 100644 index 00000000..65177476 --- /dev/null +++ b/gquiche/quic/test_tools/web_transport_resets_backend.h @@ -0,0 +1,23 @@ +// Copyright (c) 2021 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_RESETS_BACKEND_H_ +#define QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_RESETS_BACKEND_H_ + +#include "gquiche/quic/test_tools/quic_test_backend.h" + +namespace quic { +namespace test { + +// A backend for testing RESET_STREAM/STOP_SENDING behavior. Provides +// bidirectional echo streams; whenever one of those receives RESET_STREAM or +// STOP_SENDING, a log message is sent as a unidirectional stream. +QuicSimpleServerBackend::WebTransportResponse WebTransportResetsBackend( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session); + +} // namespace test +} // namespace quic + +#endif // QUICHE_QUIC_TEST_TOOLS_WEB_TRANSPORT_RESETS_BACKEND_H_ diff --git a/gquiche/quic/tools/quic_backend_response.h b/gquiche/quic/tools/quic_backend_response.h index 40bee13c..3c09e929 100644 --- a/gquiche/quic/tools/quic_backend_response.h +++ b/gquiche/quic/tools/quic_backend_response.h @@ -17,6 +17,7 @@ class QuicBackendResponse { public: // A ServerPushInfo contains path of the push request and everything needed in // comprising a response for the push request. + // TODO(b/171463363): Remove. struct ServerPushInfo { ServerPushInfo(QuicUrl request_url, spdy::Http2HeaderBlock headers, diff --git a/gquiche/quic/tools/quic_client.cc b/gquiche/quic/tools/quic_client.cc index 1e5eee54..3d76a7e3 100644 --- a/gquiche/quic/tools/quic_client.cc +++ b/gquiche/quic/tools/quic_client.cc @@ -164,7 +164,8 @@ std::unique_ptr QuicClient::CreateQuicClientSession( QuicConnection* connection) { return std::make_unique( *config(), supported_versions, connection, server_id(), crypto_config(), - push_promise_index(), drop_response_body(), enable_web_transport()); + push_promise_index(), drop_response_body(), enable_web_transport(), + use_datagram_contexts()); } QuicClientEpollNetworkHelper* QuicClient::epoll_network_helper() { diff --git a/gquiche/quic/tools/quic_client_base.cc b/gquiche/quic/tools/quic_client_base.cc index 489fe4c0..1f09fa29 100644 --- a/gquiche/quic/tools/quic_client_base.cc +++ b/gquiche/quic/tools/quic_client_base.cc @@ -3,6 +3,8 @@ // found in the LICENSE file. #include "gquiche/quic/tools/quic_client_base.h" + +#include #include #include "gquiche/quic/core/crypto/quic_random.h" @@ -64,6 +66,7 @@ class QuicClientSocketMigrationValidationResultDelegate std::unique_ptr context) override { QUIC_LOG(WARNING) << "Fail to validate path " << *context << ", stop migrating."; + client_->session()->connection()->OnPathValidationFailureAtClient(); } private: @@ -182,6 +185,9 @@ void QuicClientBase::StartConnect() { server_address(), helper(), alarm_factory(), writer, /* owns_writer= */ false, Perspective::IS_CLIENT, client_supported_versions)); + if (can_reconnect_with_different_version) { + session()->set_client_original_supported_versions(supported_versions()); + } if (connection_debug_visitor_ != nullptr) { session()->connection()->set_debug_visitor(connection_debug_visitor_); } @@ -257,6 +263,8 @@ bool QuicClientBase::MigrateSocketWithSpecifiedPort( const QuicIpAddress& new_host, int port) { if (!connected()) { + QUICHE_DVLOG(1) + << "MigrateSocketWithSpecifiedPort failed as connection has closed"; return false; } @@ -264,11 +272,17 @@ bool QuicClientBase::MigrateSocketWithSpecifiedPort( std::unique_ptr writer = CreateWriterForNewNetwork(new_host, port); if (writer == nullptr) { + QUICHE_DVLOG(1) + << "MigrateSocketWithSpecifiedPort failed from writer creation"; + return false; + } + if (!session()->MigratePath(network_helper_->GetLatestClientAddress(), + session()->connection()->peer_address(), + writer.get(), false)) { + QUICHE_DVLOG(1) + << "MigrateSocketWithSpecifiedPort failed from session()->MigratePath"; return false; } - session()->MigratePath(network_helper_->GetLatestClientAddress(), - session()->connection()->peer_address(), writer.get(), - false); set_writer(writer.release()); return true; } @@ -433,13 +447,20 @@ QuicConnectionId QuicClientBase::GetClientConnectionId() { bool QuicClientBase::CanReconnectWithDifferentVersion( ParsedQuicVersion* version) const { if (session_ == nullptr || session_->connection() == nullptr || - session_->error() != QUIC_INVALID_VERSION || - session_->connection()->server_supported_versions().empty()) { + session_->error() != QUIC_INVALID_VERSION) { return false; } + + const auto& server_supported_versions = + session_->connection()->server_supported_versions(); + if (server_supported_versions.empty()) { + return false; + } + for (const auto& client_version : supported_versions_) { - if (QuicContainsValue(session_->connection()->server_supported_versions(), - client_version)) { + if (std::find(server_supported_versions.begin(), + server_supported_versions.end(), + client_version) != server_supported_versions.end()) { *version = client_version; return true; } diff --git a/gquiche/quic/tools/quic_client_base.h b/gquiche/quic/tools/quic_client_base.h index c15bbad8..7d63b9f0 100644 --- a/gquiche/quic/tools/quic_client_base.h +++ b/gquiche/quic/tools/quic_client_base.h @@ -145,6 +145,11 @@ class QuicClientBase { crypto_config_.set_user_agent_id(user_agent_id); } + void SetTlsSignatureAlgorithms(std::string signature_algorithms) { + crypto_config_.set_tls_signature_algorithms( + std::move(signature_algorithms)); + } + const ParsedQuicVersionVector& supported_versions() const { return supported_versions_; } diff --git a/gquiche/quic/tools/quic_client_epoll_network_helper.h b/gquiche/quic/tools/quic_client_epoll_network_helper.h index 33d84eb6..1b52412a 100644 --- a/gquiche/quic/tools/quic_client_epoll_network_helper.h +++ b/gquiche/quic/tools/quic_client_epoll_network_helper.h @@ -15,9 +15,9 @@ #include "gquiche/quic/core/http/quic_client_push_promise_index.h" #include "gquiche/quic/core/quic_config.h" #include "gquiche/quic/core/quic_packet_reader.h" -#include "gquiche/quic/platform/api/quic_containers.h" #include "gquiche/quic/platform/api/quic_epoll.h" #include "gquiche/quic/tools/quic_client_base.h" +#include "gquiche/common/quiche_linked_hash_map.h" namespace quic { @@ -73,7 +73,8 @@ class QuicClientEpollNetworkHelper : public QuicClientBase::NetworkHelper, QuicEpollServer* epoll_server() { return epoll_server_; } - const QuicLinkedHashMap& fd_address_map() const { + const quiche::QuicheLinkedHashMap& fd_address_map() + const { return fd_address_map_; } @@ -110,7 +111,7 @@ class QuicClientEpollNetworkHelper : public QuicClientBase::NetworkHelper, // Map mapping created UDP sockets to their addresses. By using linked hash // map, the order of socket creation can be recorded. - QuicLinkedHashMap fd_address_map_; + quiche::QuicheLinkedHashMap fd_address_map_; // If overflow_supported_ is true, this will be the number of packets dropped // during the lifetime of the server. diff --git a/gquiche/quic/tools/quic_client_interop_test_bin.cc b/gquiche/quic/tools/quic_client_interop_test_bin.cc index 0c53db6c..c1ce6da8 100644 --- a/gquiche/quic/tools/quic_client_interop_test_bin.cc +++ b/gquiche/quic/tools/quic_client_interop_test_bin.cc @@ -8,14 +8,14 @@ #include #include "absl/strings/str_cat.h" +#include "gquiche/quic/core/crypto/quic_client_session_cache.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_versions.h" #include "gquiche/quic/platform/api/quic_epoll.h" #include "gquiche/quic/platform/api/quic_system_event_loop.h" -#include "platform/quic_epoll_clock.h" +#include "platform/quic_platform_impl/quic_epoll_clock.h" #include "gquiche/quic/test_tools/quic_connection_peer.h" #include "gquiche/quic/test_tools/quic_session_peer.h" -#include "gquiche/quic/test_tools/simple_session_cache.h" #include "gquiche/quic/tools/fake_proof_verifier.h" #include "gquiche/quic/tools/quic_client.h" #include "gquiche/quic/tools/quic_url.h" @@ -210,7 +210,7 @@ void QuicClientInteropRunner::AttemptRequest(QuicSocketAddress addr, } auto proof_verifier = std::make_unique(); - auto session_cache = std::make_unique(); + auto session_cache = std::make_unique(); QuicEpollServer epoll_server; QuicEpollClock epoll_clock(&epoll_server); QuicConfig config; diff --git a/gquiche/quic/tools/quic_client_test.cc b/gquiche/quic/tools/quic_client_test.cc index e65f0e19..67b8474e 100644 --- a/gquiche/quic/tools/quic_client_test.cc +++ b/gquiche/quic/tools/quic_client_test.cc @@ -18,7 +18,7 @@ #include "gquiche/quic/platform/api/quic_test_loopback.h" #include "gquiche/quic/test_tools/crypto_test_utils.h" #include "gquiche/quic/test_tools/quic_client_peer.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace quic { namespace test { diff --git a/gquiche/quic/tools/quic_epoll_client_factory.cc b/gquiche/quic/tools/quic_epoll_client_factory.cc index 4d38a725..9ee743b4 100644 --- a/gquiche/quic/tools/quic_epoll_client_factory.cc +++ b/gquiche/quic/tools/quic_epoll_client_factory.cc @@ -24,7 +24,8 @@ std::unique_ptr QuicEpollClientFactory::CreateClient( uint16_t port, ParsedQuicVersionVector versions, const QuicConfig& config, - std::unique_ptr verifier) { + std::unique_ptr verifier, + std::unique_ptr session_cache) { QuicSocketAddress addr = tools::LookupAddress( address_family_for_lookup, host_for_lookup, absl::StrCat(port)); if (!addr.IsInitialized()) { @@ -34,7 +35,7 @@ std::unique_ptr QuicEpollClientFactory::CreateClient( QuicServerId server_id(host_for_handshake, port, false); return std::make_unique(addr, server_id, versions, config, &epoll_server_, std::move(verifier), - nullptr); + std::move(session_cache)); } } // namespace quic diff --git a/gquiche/quic/tools/quic_epoll_client_factory.h b/gquiche/quic/tools/quic_epoll_client_factory.h index a93823a9..2387b8d1 100644 --- a/gquiche/quic/tools/quic_epoll_client_factory.h +++ b/gquiche/quic/tools/quic_epoll_client_factory.h @@ -20,7 +20,8 @@ class QuicEpollClientFactory : public QuicToyClient::ClientFactory { uint16_t port, ParsedQuicVersionVector versions, const QuicConfig& config, - std::unique_ptr verifier) override; + std::unique_ptr verifier, + std::unique_ptr session_cache) override; private: QuicEpollServer epoll_server_; diff --git a/gquiche/quic/tools/quic_memory_cache_backend.cc b/gquiche/quic/tools/quic_memory_cache_backend.cc index 49c2438a..97934c02 100644 --- a/gquiche/quic/tools/quic_memory_cache_backend.cc +++ b/gquiche/quic/tools/quic_memory_cache_backend.cc @@ -12,10 +12,10 @@ #include "absl/strings/string_view.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/platform/api/quic_bug_tracker.h" -#include "gquiche/quic/platform/api/quic_file_utils.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/quic/tools/web_transport_test_visitors.h" +#include "gquiche/common/platform/api/quiche_file_utils.h" +#include "gquiche/common/quiche_text_utils.h" using spdy::Http2HeaderBlock; using spdy::kV3LowestPriority; @@ -28,12 +28,19 @@ QuicMemoryCacheBackend::ResourceFile::ResourceFile(const std::string& file_name) QuicMemoryCacheBackend::ResourceFile::~ResourceFile() = default; void QuicMemoryCacheBackend::ResourceFile::Read() { - ReadFileContents(file_name_, &file_contents_); + absl::optional maybe_file_contents = + quiche::ReadFileContents(file_name_); + if (!maybe_file_contents) { + QUIC_LOG(DFATAL) << "Failed to read file for the memory cache backend: " + << file_name_; + return; + } + file_contents_ = *maybe_file_contents; // First read the headers. size_t start = 0; while (start < file_contents_.length()) { - size_t pos = file_contents_.find("\n", start); + size_t pos = file_contents_.find('\n', start); if (pos == std::string::npos) { QUIC_LOG(DFATAL) << "Headers invalid or empty, ignoring: " << file_name_; return; @@ -51,7 +58,7 @@ void QuicMemoryCacheBackend::ResourceFile::Read() { } // Extract the status from the HTTP first line. if (line.substr(0, 4) == "HTTP") { - pos = line.find(" "); + pos = line.find(' '); if (pos == std::string::npos) { QUIC_LOG(DFATAL) << "Headers invalid or empty, ignoring: " << file_name_; @@ -139,8 +146,7 @@ void QuicMemoryCacheBackend::ResourceFile::HandleXOriginalUrl() { } const QuicBackendResponse* QuicMemoryCacheBackend::GetResponse( - absl::string_view host, - absl::string_view path) const { + absl::string_view host, absl::string_view path) const { QuicWriterMutexLock lock(&response_mutex_); auto it = responses_.find(GetKey(host, path)); @@ -178,11 +184,8 @@ void QuicMemoryCacheBackend::AddSimpleResponse(absl::string_view host, } void QuicMemoryCacheBackend::AddSimpleResponseWithServerPushResources( - absl::string_view host, - absl::string_view path, - int response_code, - absl::string_view body, - std::list push_resources) { + absl::string_view host, absl::string_view path, int response_code, + absl::string_view body, std::list push_resources) { AddSimpleResponse(host, path, response_code, body); MaybeAddServerPushResources(host, path, push_resources); } @@ -213,10 +216,8 @@ void QuicMemoryCacheBackend::AddResponse(absl::string_view host, } void QuicMemoryCacheBackend::AddResponseWithEarlyHints( - absl::string_view host, - absl::string_view path, - spdy::Http2HeaderBlock response_headers, - absl::string_view response_body, + absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, absl::string_view response_body, const std::vector& early_hints) { AddResponseImpl(host, path, QuicBackendResponse::REGULAR_RESPONSE, std::move(response_headers), response_body, @@ -224,18 +225,15 @@ void QuicMemoryCacheBackend::AddResponseWithEarlyHints( } void QuicMemoryCacheBackend::AddSpecialResponse( - absl::string_view host, - absl::string_view path, + absl::string_view host, absl::string_view path, SpecialResponseType response_type) { AddResponseImpl(host, path, response_type, Http2HeaderBlock(), "", Http2HeaderBlock(), std::vector()); } void QuicMemoryCacheBackend::AddSpecialResponse( - absl::string_view host, - absl::string_view path, - spdy::Http2HeaderBlock response_headers, - absl::string_view response_body, + absl::string_view host, absl::string_view path, + spdy::Http2HeaderBlock response_headers, absl::string_view response_body, SpecialResponseType response_type) { AddResponseImpl(host, path, response_type, std::move(response_headers), response_body, Http2HeaderBlock(), @@ -253,7 +251,12 @@ bool QuicMemoryCacheBackend::InitializeBackend( QUIC_LOG(INFO) << "Attempting to initialize QuicMemoryCacheBackend from directory: " << cache_directory; - std::vector files = ReadFileContents(cache_directory); + std::vector files; + if (!quiche::EnumerateDirectoryRecursively(cache_directory, files)) { + QUIC_BUG(QuicMemoryCacheBackend unreadable directory) + << "Can't read QuicMemoryCacheBackend directory: " << cache_directory; + return false; + } std::list> resource_files; for (const auto& filename : files) { std::unique_ptr resource_file(new ResourceFile(filename)); @@ -313,6 +316,10 @@ void QuicMemoryCacheBackend::GenerateDynamicResponses() { QuicBackendResponse::GENERATE_BYTES); } +void QuicMemoryCacheBackend::EnableWebTransport() { + enable_webtransport_ = true; +} + bool QuicMemoryCacheBackend::IsBackendInitialized() const { return cache_initialized_; } @@ -336,11 +343,10 @@ void QuicMemoryCacheBackend::FetchResponseFromBackend( if (path != request_headers.end()) { request_url += std::string(path->second); } - std::list resources = GetServerPushResources(request_url); QUIC_DVLOG(1) << "Fetching QUIC response from backend in-memory cache for url " << request_url; - quic_stream->OnResponseBackendComplete(quic_response, resources); + quic_stream->OnResponseBackendComplete(quic_response); } // The memory cache does not have a per-stream handler @@ -361,6 +367,35 @@ std::list QuicMemoryCacheBackend::GetServerPushResources( return resources; } +QuicMemoryCacheBackend::WebTransportResponse +QuicMemoryCacheBackend::ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session) { + if (!SupportsWebTransport()) { + return QuicSimpleServerBackend::ProcessWebTransportRequest(request_headers, + session); + } + + auto path_it = request_headers.find(":path"); + if (path_it == request_headers.end()) { + WebTransportResponse response; + response.response_headers[":status"] = "400"; + return response; + } + absl::string_view path = path_it->second; + if (path == "/echo") { + WebTransportResponse response; + response.response_headers[":status"] = "200"; + response.visitor = + std::make_unique(session); + return response; + } + + WebTransportResponse response; + response.response_headers[":status"] = "404"; + return response; +} + QuicMemoryCacheBackend::~QuicMemoryCacheBackend() { { QuicWriterMutexLock lock(&response_mutex_); @@ -369,19 +404,16 @@ QuicMemoryCacheBackend::~QuicMemoryCacheBackend() { } void QuicMemoryCacheBackend::AddResponseImpl( - absl::string_view host, - absl::string_view path, - SpecialResponseType response_type, - Http2HeaderBlock response_headers, - absl::string_view response_body, - Http2HeaderBlock response_trailers, + absl::string_view host, absl::string_view path, + SpecialResponseType response_type, Http2HeaderBlock response_headers, + absl::string_view response_body, Http2HeaderBlock response_trailers, const std::vector& early_hints) { QuicWriterMutexLock lock(&response_mutex_); QUICHE_DCHECK(!host.empty()) << "Host must be populated, e.g. \"www.google.com\""; std::string key = GetKey(host, path); - if (QuicContainsKey(responses_, key)) { + if (responses_.contains(key)) { QUIC_BUG(quic_bug_10932_3) << "Response for '" << key << "' already exists!"; return; @@ -408,8 +440,7 @@ std::string QuicMemoryCacheBackend::GetKey(absl::string_view host, } void QuicMemoryCacheBackend::MaybeAddServerPushResources( - absl::string_view request_host, - absl::string_view request_path, + absl::string_view request_host, absl::string_view request_path, std::list push_resources) { std::string request_url = GetKey(request_host, request_path); @@ -435,7 +466,7 @@ void QuicMemoryCacheBackend::MaybeAddServerPushResources( bool found_existing_response = false; { QuicWriterMutexLock lock(&response_mutex_); - found_existing_response = QuicContainsKey(responses_, GetKey(host, path)); + found_existing_response = responses_.contains(GetKey(host, path)); } if (!found_existing_response) { // Add a server push response to responses map, if it is not in the map. @@ -448,8 +479,7 @@ void QuicMemoryCacheBackend::MaybeAddServerPushResources( } bool QuicMemoryCacheBackend::PushResourceExistsInCache( - std::string original_request_url, - ServerPushInfo resource) { + std::string original_request_url, ServerPushInfo resource) { QuicWriterMutexLock lock(&response_mutex_); auto resource_range = server_push_resources_.equal_range(original_request_url); diff --git a/gquiche/quic/tools/quic_memory_cache_backend.h b/gquiche/quic/tools/quic_memory_cache_backend.h index 2c8785e2..224af249 100644 --- a/gquiche/quic/tools/quic_memory_cache_backend.h +++ b/gquiche/quic/tools/quic_memory_cache_backend.h @@ -90,6 +90,7 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { // some server push resources(resource path, corresponding response status and // path) associated with it. // Push resource implicitly come from the same host. + // TODO(b/171463363): Remove. void AddSimpleResponseWithServerPushResources( absl::string_view host, absl::string_view path, @@ -139,7 +140,10 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { // generated response of that many bytes. void GenerateDynamicResponses(); + void EnableWebTransport(); + // Find all the server push resources associated with |request_url|. + // TODO(b/171463363): Remove. std::list GetServerPushResources( std::string request_url); @@ -153,6 +157,10 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { QuicSimpleServerBackend::RequestHandler* quic_server_stream) override; void CloseBackendResponseStream( QuicSimpleServerBackend::RequestHandler* quic_server_stream) override; + WebTransportResponse ProcessWebTransportRequest( + const spdy::Http2HeaderBlock& request_headers, + WebTransportSession* session) override; + bool SupportsWebTransport() override { return enable_webtransport_; } private: void AddResponseImpl(absl::string_view host, @@ -167,6 +175,7 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { // Add some server push urls with given responses for specified // request if these push resources are not associated with this request yet. + // TODO(b/171463363): Remove. void MaybeAddServerPushResources( absl::string_view request_host, absl::string_view request_path, @@ -174,6 +183,7 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { // Check if push resource(push_host/push_path) associated with given request // url already exists in server push map. + // TODO(b/171463363): Remove. bool PushResourceExistsInCache(std::string original_request_url, QuicBackendResponse::ServerPushInfo resource); @@ -190,6 +200,7 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { QUIC_GUARDED_BY(response_mutex_); // A map from request URL to associated server push responses (if any). + // TODO(b/171463363): Remove. std::multimap server_push_resources_ QUIC_GUARDED_BY(response_mutex_); @@ -197,6 +208,8 @@ class QuicMemoryCacheBackend : public QuicSimpleServerBackend { // server threads accessing those responses. mutable QuicMutex response_mutex_; bool cache_initialized_; + + bool enable_webtransport_ = false; }; } // namespace quic diff --git a/gquiche/quic/tools/quic_memory_cache_backend_test.cc b/gquiche/quic/tools/quic_memory_cache_backend_test.cc index 981cb070..265f78c2 100644 --- a/gquiche/quic/tools/quic_memory_cache_backend_test.cc +++ b/gquiche/quic/tools/quic_memory_cache_backend_test.cc @@ -4,12 +4,13 @@ #include "gquiche/quic/tools/quic_memory_cache_backend.h" +#include + #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "gquiche/quic/platform/api/quic_file_utils.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/platform/api/quic_test.h" #include "gquiche/quic/tools/quic_backend_response.h" +#include "gquiche/common/platform/api/quiche_file_utils.h" namespace quic { namespace test { @@ -21,8 +22,7 @@ using ServerPushInfo = QuicBackendResponse::ServerPushInfo; class QuicMemoryCacheBackendTest : public QuicTest { protected: - void CreateRequest(std::string host, - std::string path, + void CreateRequest(std::string host, std::string path, spdy::Http2HeaderBlock* headers) { (*headers)[":method"] = "GET"; (*headers)[":path"] = path; @@ -49,7 +49,7 @@ TEST_F(QuicMemoryCacheBackendTest, AddSimpleResponseGetResponse) { CreateRequest("www.google.com", "/", &request_headers); const Response* response = cache_.GetResponse("www.google.com", "/"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("200", response->headers().find(":status")->second); EXPECT_EQ(response_body.size(), response->body().length()); } @@ -77,51 +77,87 @@ TEST_F(QuicMemoryCacheBackendTest, AddResponse) { EXPECT_EQ(response->trailers(), response_trailers); } -TEST_F(QuicMemoryCacheBackendTest, ReadsCacheDir) { +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_ReadsCacheDir DISABLED_ReadsCacheDir +#else +#define MAYBE_ReadsCacheDir ReadsCacheDir +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_ReadsCacheDir) { cache_.InitializeBackend(CacheDirectory()); const Response* response = cache_.GetResponse("test.example.com", "/index.html"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("200", response->headers().find(":status")->second); // Connection headers are not valid in HTTP/2. - EXPECT_FALSE(QuicContainsKey(response->headers(), "connection")); + EXPECT_FALSE(response->headers().contains("connection")); EXPECT_LT(0U, response->body().length()); } -TEST_F(QuicMemoryCacheBackendTest, ReadsCacheDirWithServerPushResource) { +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_ReadsCacheDirWithServerPushResource \ + DISABLED_ReadsCacheDirWithServerPushResource +#else +#define MAYBE_ReadsCacheDirWithServerPushResource \ + ReadsCacheDirWithServerPushResource +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_ReadsCacheDirWithServerPushResource) { cache_.InitializeBackend(CacheDirectory() + "_with_push"); std::list resources = cache_.GetServerPushResources("test.example.com/"); ASSERT_EQ(1UL, resources.size()); } -TEST_F(QuicMemoryCacheBackendTest, ReadsCacheDirWithServerPushResources) { +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_ReadsCacheDirWithServerPushResources \ + DISABLED_ReadsCacheDirWithServerPushResources +#else +#define MAYBE_ReadsCacheDirWithServerPushResources \ + ReadsCacheDirWithServerPushResources +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_ReadsCacheDirWithServerPushResources) { cache_.InitializeBackend(CacheDirectory() + "_with_push"); std::list resources = cache_.GetServerPushResources("test.example.com/index2.html"); ASSERT_EQ(2UL, resources.size()); } -TEST_F(QuicMemoryCacheBackendTest, UsesOriginalUrl) { +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_UsesOriginalUrl DISABLED_UsesOriginalUrl +#else +#define MAYBE_UsesOriginalUrl UsesOriginalUrl +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_UsesOriginalUrl) { cache_.InitializeBackend(CacheDirectory()); const Response* response = cache_.GetResponse("test.example.com", "/site_map.html"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("200", response->headers().find(":status")->second); // Connection headers are not valid in HTTP/2. - EXPECT_FALSE(QuicContainsKey(response->headers(), "connection")); + EXPECT_FALSE(response->headers().contains("connection")); EXPECT_LT(0U, response->body().length()); } -TEST_F(QuicMemoryCacheBackendTest, UsesOriginalUrlOnly) { +// TODO(crbug.com/1249712) This test is failing on iOS. +#if defined(OS_IOS) +#define MAYBE_UsesOriginalUrlOnly DISABLED_UsesOriginalUrlOnly +#else +#define MAYBE_UsesOriginalUrlOnly UsesOriginalUrlOnly +#endif +TEST_F(QuicMemoryCacheBackendTest, MAYBE_UsesOriginalUrlOnly) { // Tests that if the URL cannot be inferred correctly from the path // because the directory does not include the hostname, that the // X-Original-Url header's value will be used. std::string dir; std::string path = "map.html"; - for (const std::string& file : ReadFileContents(CacheDirectory())) { + std::vector files; + ASSERT_TRUE(quiche::EnumerateDirectoryRecursively(CacheDirectory(), files)); + for (const std::string& file : files) { if (absl::EndsWithIgnoreCase(file, "map.html")) { dir = file; dir.erase(dir.length() - path.length() - 1); @@ -134,10 +170,10 @@ TEST_F(QuicMemoryCacheBackendTest, UsesOriginalUrlOnly) { const Response* response = cache_.GetResponse("test.example.com", "/site_map.html"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("200", response->headers().find(":status")->second); // Connection headers are not valid in HTTP/2. - EXPECT_FALSE(QuicContainsKey(response->headers(), "connection")); + EXPECT_FALSE(response->headers().contains("connection")); EXPECT_LT(0U, response->body().length()); } @@ -157,20 +193,20 @@ TEST_F(QuicMemoryCacheBackendTest, DefaultResponse) { // Now we should get the default response for the original request. response = cache_.GetResponse("www.google.com", "/"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("200", response->headers().find(":status")->second); // Now add a set response for / and make sure it is returned cache_.AddSimpleResponse("www.google.com", "/", 302, ""); response = cache_.GetResponse("www.google.com", "/"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("302", response->headers().find(":status")->second); // We should get the default response for other requests. response = cache_.GetResponse("www.google.com", "/asd"); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ("200", response->headers().find(":status")->second); } @@ -243,7 +279,7 @@ TEST_F(QuicMemoryCacheBackendTest, GetServerPushResourcesAndPushResponses) { std::string path = url.path(); const Response* response = cache_.GetResponse(host, path); ASSERT_TRUE(response); - ASSERT_TRUE(QuicContainsKey(response->headers(), ":status")); + ASSERT_TRUE(response->headers().contains(":status")); EXPECT_EQ(push_response_status[i++], response->headers().find(":status")->second); EXPECT_EQ(push_resource.body, response->body()); diff --git a/gquiche/quic/tools/quic_packet_printer_bin.cc b/gquiche/quic/tools/quic_packet_printer_bin.cc index 0f2660e3..6265de8a 100644 --- a/gquiche/quic/tools/quic_packet_printer_bin.cc +++ b/gquiche/quic/tools/quic_packet_printer_bin.cc @@ -34,7 +34,7 @@ #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" DEFINE_QUIC_COMMAND_LINE_FLAG(std::string, quic_version, diff --git a/gquiche/quic/tools/quic_reject_reason_decoder_bin.cc b/gquiche/quic/tools/quic_reject_reason_decoder_bin.cc index f2244869..9f415c08 100644 --- a/gquiche/quic/tools/quic_reject_reason_decoder_bin.cc +++ b/gquiche/quic/tools/quic_reject_reason_decoder_bin.cc @@ -11,7 +11,7 @@ #include "gquiche/quic/core/crypto/crypto_handshake.h" #include "gquiche/quic/core/crypto/crypto_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" using quic::CryptoUtils; using quic::HandshakeFailureReason; diff --git a/gquiche/quic/tools/quic_simple_client_session.cc b/gquiche/quic/tools/quic_simple_client_session.cc index d0e10f9a..548c6d2f 100644 --- a/gquiche/quic/tools/quic_simple_client_session.cc +++ b/gquiche/quic/tools/quic_simple_client_session.cc @@ -9,39 +9,27 @@ namespace quic { QuicSimpleClientSession::QuicSimpleClientSession( - const QuicConfig& config, - const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - const QuicServerId& server_id, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, - QuicClientPushPromiseIndex* push_promise_index, - bool drop_response_body) - : QuicSimpleClientSession(config, - supported_versions, - connection, - server_id, - crypto_config, - push_promise_index, + QuicClientPushPromiseIndex* push_promise_index, bool drop_response_body) + : QuicSimpleClientSession(config, supported_versions, connection, server_id, + crypto_config, push_promise_index, drop_response_body, - /*enable_web_transport=*/false) {} + /*enable_web_transport=*/false, + /*use_datagram_contexts=*/false) {} QuicSimpleClientSession::QuicSimpleClientSession( - const QuicConfig& config, - const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - const QuicServerId& server_id, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, - QuicClientPushPromiseIndex* push_promise_index, - bool drop_response_body, - bool enable_web_transport) - : QuicSpdyClientSession(config, - supported_versions, - connection, - server_id, - crypto_config, - push_promise_index), + QuicClientPushPromiseIndex* push_promise_index, bool drop_response_body, + bool enable_web_transport, bool use_datagram_contexts) + : QuicSpdyClientSession(config, supported_versions, connection, server_id, + crypto_config, push_promise_index), drop_response_body_(drop_response_body), - enable_web_transport_(enable_web_transport) {} + enable_web_transport_(enable_web_transport), + use_datagram_contexts_(use_datagram_contexts) {} std::unique_ptr QuicSimpleClientSession::CreateClientStream() { @@ -54,4 +42,13 @@ bool QuicSimpleClientSession::ShouldNegotiateWebTransport() { return enable_web_transport_; } +bool QuicSimpleClientSession::ShouldNegotiateDatagramContexts() { + return use_datagram_contexts_; +} + +HttpDatagramSupport QuicSimpleClientSession::LocalHttpDatagramSupport() { + return enable_web_transport_ ? HttpDatagramSupport::kDraft04 + : HttpDatagramSupport::kNone; +} + } // namespace quic diff --git a/gquiche/quic/tools/quic_simple_client_session.h b/gquiche/quic/tools/quic_simple_client_session.h index e396a2e7..38bd9c88 100644 --- a/gquiche/quic/tools/quic_simple_client_session.h +++ b/gquiche/quic/tools/quic_simple_client_session.h @@ -25,15 +25,18 @@ class QuicSimpleClientSession : public QuicSpdyClientSession { const QuicServerId& server_id, QuicCryptoClientConfig* crypto_config, QuicClientPushPromiseIndex* push_promise_index, - bool drop_response_body, - bool enable_web_transport); + bool drop_response_body, bool enable_web_transport, + bool use_datagram_contexts); std::unique_ptr CreateClientStream() override; bool ShouldNegotiateWebTransport() override; + bool ShouldNegotiateDatagramContexts() override; + HttpDatagramSupport LocalHttpDatagramSupport() override; private: const bool drop_response_body_; const bool enable_web_transport_; + const bool use_datagram_contexts_; }; } // namespace quic diff --git a/gquiche/quic/tools/quic_simple_dispatcher.cc b/gquiche/quic/tools/quic_simple_dispatcher.cc index 6ce9fe09..7eccda26 100644 --- a/gquiche/quic/tools/quic_simple_dispatcher.cc +++ b/gquiche/quic/tools/quic_simple_dispatcher.cc @@ -49,12 +49,10 @@ void QuicSimpleDispatcher::OnRstStreamReceived( } std::unique_ptr QuicSimpleDispatcher::CreateQuicSession( - QuicConnectionId connection_id, - const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view /*alpn*/, + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, const ParsedQuicVersion& version, - absl::string_view /*sni*/) { + const ParsedClientHello& /*parsed_chlo*/) { // The QuicServerSessionBase takes ownership of |connection| below. QuicConnection* connection = new QuicConnection(connection_id, self_address, peer_address, helper(), diff --git a/gquiche/quic/tools/quic_simple_dispatcher.h b/gquiche/quic/tools/quic_simple_dispatcher.h index 3b06859d..01827dac 100644 --- a/gquiche/quic/tools/quic_simple_dispatcher.h +++ b/gquiche/quic/tools/quic_simple_dispatcher.h @@ -32,12 +32,10 @@ class QuicSimpleDispatcher : public QuicDispatcher { protected: std::unique_ptr CreateQuicSession( - QuicConnectionId connection_id, - const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, + QuicConnectionId connection_id, const QuicSocketAddress& self_address, + const QuicSocketAddress& peer_address, absl::string_view alpn, const ParsedQuicVersion& version, - absl::string_view sni) override; + const ParsedClientHello& parsed_chlo) override; QuicSimpleServerBackend* server_backend() { return quic_simple_server_backend_; diff --git a/gquiche/quic/tools/quic_simple_server_backend.h b/gquiche/quic/tools/quic_simple_server_backend.h index 60279a1a..f6e3e166 100644 --- a/gquiche/quic/tools/quic_simple_server_backend.h +++ b/gquiche/quic/tools/quic_simple_server_backend.h @@ -32,8 +32,7 @@ class QuicSimpleServerBackend { // Called when the response is ready at the backend and can be send back to // the QUIC client. virtual void OnResponseBackendComplete( - const QuicBackendResponse* response, - std::list resources) = 0; + const QuicBackendResponse* response) = 0; }; struct WebTransportResponse { @@ -68,6 +67,8 @@ class QuicSimpleServerBackend { return response; } virtual bool SupportsWebTransport() { return false; } + virtual bool UsesDatagramContexts() { return false; } + virtual bool SupportsExtendedConnect() { return true; } }; } // namespace quic diff --git a/gquiche/quic/tools/quic_simple_server_session.cc b/gquiche/quic/tools/quic_simple_server_session.cc index 2bb76611..beb47882 100644 --- a/gquiche/quic/tools/quic_simple_server_session.cc +++ b/gquiche/quic/tools/quic_simple_server_session.cc @@ -10,6 +10,7 @@ #include "gquiche/quic/core/http/quic_server_initiated_spdy_stream.h" #include "gquiche/quic/core/http/quic_spdy_session.h" #include "gquiche/quic/core/quic_connection.h" +#include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" @@ -18,30 +19,21 @@ namespace quic { QuicSimpleServerSession::QuicSimpleServerSession( - const QuicConfig& config, - const ParsedQuicVersionVector& supported_versions, - QuicConnection* connection, - QuicSession::Visitor* visitor, + const QuicConfig& config, const ParsedQuicVersionVector& supported_versions, + QuicConnection* connection, QuicSession::Visitor* visitor, QuicCryptoServerStreamBase::Helper* helper, const QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache, QuicSimpleServerBackend* quic_simple_server_backend) - : QuicServerSessionBase(config, - supported_versions, - connection, - visitor, - helper, - crypto_config, - compressed_certs_cache), + : QuicServerSessionBase(config, supported_versions, connection, visitor, + helper, crypto_config, compressed_certs_cache), highest_promised_stream_id_( QuicUtils::GetInvalidStreamId(connection->transport_version())), quic_simple_server_backend_(quic_simple_server_backend) { QUICHE_DCHECK(quic_simple_server_backend_); } -QuicSimpleServerSession::~QuicSimpleServerSession() { - DeleteConnection(); -} +QuicSimpleServerSession::~QuicSimpleServerSession() { DeleteConnection(); } std::unique_ptr QuicSimpleServerSession::CreateQuicCryptoServerStream( @@ -62,39 +54,6 @@ void QuicSimpleServerSession::OnStreamFrame(const QuicStreamFrame& frame) { QuicSpdySession::OnStreamFrame(frame); } -void QuicSimpleServerSession::PromisePushResources( - const std::string& request_url, - const std::list& resources, - QuicStreamId original_stream_id, - const spdy::SpdyStreamPrecedence& /* original_precedence */, - const spdy::Http2HeaderBlock& original_request_headers) { - if (!server_push_enabled()) { - return; - } - - for (const QuicBackendResponse::ServerPushInfo& resource : resources) { - spdy::Http2HeaderBlock headers = SynthesizePushRequestHeaders( - request_url, resource, original_request_headers); - // TODO(b/136295430): Use sequential push IDs for IETF QUIC. - auto new_highest_promised_stream_id = - highest_promised_stream_id_ + - QuicUtils::StreamIdDelta(transport_version()); - if (VersionUsesHttp3(transport_version()) && - !CanCreatePushStreamWithId(new_highest_promised_stream_id)) { - return; - } - highest_promised_stream_id_ = new_highest_promised_stream_id; - SendPushPromise(original_stream_id, highest_promised_stream_id_, - headers.Clone()); - promised_streams_.push_back( - PromisedStreamInfo(std::move(headers), highest_promised_stream_id_, - spdy::SpdyStreamPrecedence(resource.priority))); - } - - // Procese promised push request as many as possible. - HandlePromisedPushRequests(); -} - QuicSpdyStream* QuicSimpleServerSession::CreateIncomingStream(QuicStreamId id) { if (!ShouldCreateIncomingStream(id)) { return nullptr; @@ -108,8 +67,8 @@ QuicSpdyStream* QuicSimpleServerSession::CreateIncomingStream(QuicStreamId id) { QuicSpdyStream* QuicSimpleServerSession::CreateIncomingStream( PendingStream* pending) { - QuicSpdyStream* stream = new QuicSimpleServerStream( - pending, this, BIDIRECTIONAL, quic_simple_server_backend_); + QuicSpdyStream* stream = + new QuicSimpleServerStream(pending, this, quic_simple_server_backend_); ActivateStream(absl::WrapUnique(stream)); return stream; } @@ -177,15 +136,15 @@ void QuicSimpleServerSession::HandleRstOnValidNonexistentStream( promised_streams_[index].is_cancelled = true; } } - control_frame_manager().WriteOrBufferRstStream(frame.stream_id, - QUIC_RST_ACKNOWLEDGEMENT, 0); + control_frame_manager().WriteOrBufferRstStream( + frame.stream_id, + QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), 0); connection()->OnStreamReset(frame.stream_id, QUIC_RST_ACKNOWLEDGEMENT); } } spdy::Http2HeaderBlock QuicSimpleServerSession::SynthesizePushRequestHeaders( - std::string request_url, - QuicBackendResponse::ServerPushInfo resource, + std::string request_url, QuicBackendResponse::ServerPushInfo resource, const spdy::Http2HeaderBlock& original_request_headers) { QuicUrl push_request_url = resource.request_url; diff --git a/gquiche/quic/tools/quic_simple_server_session.h b/gquiche/quic/tools/quic_simple_server_session.h index 7dba5dec..fab9b62e 100644 --- a/gquiche/quic/tools/quic_simple_server_session.h +++ b/gquiche/quic/tools/quic_simple_server_session.h @@ -68,19 +68,12 @@ class QuicSimpleServerSession : public QuicServerSessionBase { // Override base class to detact client sending data on server push stream. void OnStreamFrame(const QuicStreamFrame& frame) override; - // Send out PUSH_PROMISE for all |resources| promised stream id in each frame - // will increase by 2 for each item in |resources|. - // And enqueue HEADERS block in those PUSH_PROMISED for sending push response - // later. - virtual void PromisePushResources( - const std::string& request_url, - const std::list& resources, - QuicStreamId original_stream_id, - const spdy::SpdyStreamPrecedence& original_precedence, - const spdy::Http2HeaderBlock& original_request_headers); - void OnCanCreateNewOutgoingStream(bool unidirectional) override; + bool ShouldNegotiateDatagramContexts() override { + return quic_simple_server_backend_->UsesDatagramContexts(); + } + protected: // QuicSession methods: QuicSpdyStream* CreateIncomingStream(QuicStreamId id) override; @@ -107,9 +100,11 @@ class QuicSimpleServerSession : public QuicServerSessionBase { bool ShouldNegotiateWebTransport() override { return quic_simple_server_backend_->SupportsWebTransport(); } - bool ShouldNegotiateHttp3Datagram() override { - return QuicServerSessionBase::ShouldNegotiateHttp3Datagram() || - ShouldNegotiateWebTransport(); + HttpDatagramSupport LocalHttpDatagramSupport() override { + if (ShouldNegotiateWebTransport()) { + return HttpDatagramSupport::kDraft00And04; + } + return QuicServerSessionBase::LocalHttpDatagramSupport(); } private: @@ -155,7 +150,7 @@ class QuicSimpleServerSession : public QuicServerSessionBase { // the queue also increases by 2 from previous one's. The front element's // stream_id is always next_outgoing_stream_id_, and the last one is always // highest_promised_stream_id_. - QuicCircularDeque promised_streams_; + quiche::QuicheCircularDeque promised_streams_; QuicSimpleServerBackend* quic_simple_server_backend_; // Not owned. }; diff --git a/gquiche/quic/tools/quic_simple_server_session_test.cc b/gquiche/quic/tools/quic_simple_server_session_test.cc index 0fde2d76..75ba72c6 100644 --- a/gquiche/quic/tools/quic_simple_server_session_test.cc +++ b/gquiche/quic/tools/quic_simple_server_session_test.cc @@ -9,7 +9,6 @@ #include #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "gquiche/quic/core/crypto/null_encrypter.h" #include "gquiche/quic/core/crypto/quic_crypto_server_config.h" #include "gquiche/quic/core/crypto/quic_random.h" @@ -51,10 +50,11 @@ namespace quic { namespace test { namespace { -using PromisedStreamInfo = QuicSimpleServerSession::PromisedStreamInfo; - -const QuicByteCount kHeadersFrameHeaderLength = 2; -const QuicByteCount kHeadersFramePayloadLength = 9; +// Data to be sent on a request stream. In Google QUIC, this is interpreted as +// DATA payload (there is no framing on request streams). In IETF QUIC, this is +// interpreted as HEADERS frame (type 0x1) with payload length 122 ('z'). Since +// no payload is included, QPACK decoder will not be invoked. +const char* const kStreamData = "\1z"; } // namespace @@ -257,7 +257,7 @@ class QuicSimpleServerSessionTest kMaxStreamsForTest); } - ParsedQuicVersionVector supported_versions = SupportedVersions(GetParam()); + ParsedQuicVersionVector supported_versions = SupportedVersions(version()); connection_ = new StrictMock( &helper_, &alarm_factory_, Perspective::IS_SERVER, supported_versions); connection_->AdvanceTime(QuicTime::Delta::FromSeconds(1)); @@ -289,8 +289,10 @@ class QuicSimpleServerSessionTest transport_version(), n); } + ParsedQuicVersion version() const { return GetParam(); } + QuicTransportVersion transport_version() const { - return GetParam().transport_version; + return version().transport_version; } void InjectStopSending(QuicStreamId stream_id, @@ -330,10 +332,9 @@ INSTANTIATE_TEST_SUITE_P(Tests, ::testing::PrintToStringParamName()); TEST_P(QuicSimpleServerSessionTest, CloseStreamDueToReset) { - // Open a stream, then reset it. - // Send two bytes of payload to open it. + // Send some data open a stream, then reset it. QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("HT")); + kStreamData); session_->OnStreamFrame(data1); EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); @@ -388,9 +389,8 @@ TEST_P(QuicSimpleServerSessionTest, NeverOpenStreamDueToReset) { EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); - // Send two bytes of payload. QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("HT")); + kStreamData); session_->OnStreamFrame(data1); // The stream should never be opened, now that the reset is received. @@ -399,11 +399,11 @@ TEST_P(QuicSimpleServerSessionTest, NeverOpenStreamDueToReset) { } TEST_P(QuicSimpleServerSessionTest, AcceptClosedStream) { - // Send (empty) compressed headers followed by two bytes of data. + // Send some data to open two streams. QuicStreamFrame frame1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("\1\0\0\0\0\0\0\0HT")); + kStreamData); QuicStreamFrame frame2(GetNthClientInitiatedBidirectionalId(1), false, 0, - absl::string_view("\3\0\0\0\0\0\0\0HT")); + kStreamData); session_->OnStreamFrame(frame1); session_->OnStreamFrame(frame2); EXPECT_EQ(2u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); @@ -431,9 +431,9 @@ TEST_P(QuicSimpleServerSessionTest, AcceptClosedStream) { // past the reset point of stream 3. As it's a closed stream we just drop the // data on the floor, but accept the packet because it has data for stream 5. QuicStreamFrame frame3(GetNthClientInitiatedBidirectionalId(0), false, 2, - absl::string_view("TP")); + kStreamData); QuicStreamFrame frame4(GetNthClientInitiatedBidirectionalId(1), false, 2, - absl::string_view("TP")); + kStreamData); session_->OnStreamFrame(frame3); session_->OnStreamFrame(frame4); // The stream should never be opened, now that the reset is received. @@ -443,7 +443,7 @@ TEST_P(QuicSimpleServerSessionTest, AcceptClosedStream) { TEST_P(QuicSimpleServerSessionTest, CreateIncomingStreamDisconnected) { // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. - if (GetParam() != AllSupportedVersions()[0]) { + if (version() != AllSupportedVersions()[0]) { return; } @@ -467,7 +467,7 @@ TEST_P(QuicSimpleServerSessionTest, CreateIncomingStream) { TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamDisconnected) { // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. - if (GetParam() != AllSupportedVersions()[0]) { + if (version() != AllSupportedVersions()[0]) { return; } @@ -486,7 +486,7 @@ TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamDisconnected) { TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamUnencrypted) { // EXPECT_QUIC_BUG tests are expensive so only run one instance of them. - if (GetParam() != AllSupportedVersions()[0]) { + if (version() != AllSupportedVersions()[0]) { return; } @@ -510,7 +510,7 @@ TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamUptoLimit) { // Receive some data to initiate a incoming stream which should not effect // creating outgoing streams. QuicStreamFrame data1(GetNthClientInitiatedBidirectionalId(0), false, 0, - absl::string_view("HT")); + kStreamData); session_->OnStreamFrame(data1); EXPECT_EQ(1u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); EXPECT_EQ(0u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get()) - @@ -559,7 +559,7 @@ TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamUptoLimit) { // Create peer initiated stream should have no problem. QuicStreamFrame data2(GetNthClientInitiatedBidirectionalId(1), false, 0, - absl::string_view("HT")); + kStreamData); session_->OnStreamFrame(data2); EXPECT_EQ(2u, QuicSessionPeer::GetNumOpenDynamicStreams(session_.get()) - /*outcoming=*/kMaxStreamsForTest); @@ -567,7 +567,7 @@ TEST_P(QuicSimpleServerSessionTest, CreateOutgoingDynamicStreamUptoLimit) { TEST_P(QuicSimpleServerSessionTest, OnStreamFrameWithEvenStreamId) { QuicStreamFrame frame(GetNthServerInitiatedUnidirectionalId(0), false, 0, - absl::string_view()); + kStreamData); EXPECT_CALL(*connection_, CloseConnection(QUIC_INVALID_STREAM_ID, "Client sent data on server push stream", _)); @@ -591,451 +591,6 @@ TEST_P(QuicSimpleServerSessionTest, GetEvenIncomingError) { QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); } -// In order to test the case where server push stream creation goes beyond -// limit, server push streams need to be hanging there instead of -// immediately closing after sending back response. -// To achieve this goal, this class resets flow control windows so that large -// responses will not be sent fully in order to prevent push streams from being -// closed immediately. -// Also adjust connection-level flow control window to ensure a large response -// can cause stream-level flow control blocked but not connection-level. -class QuicSimpleServerSessionServerPushTest - : public QuicSimpleServerSessionTest { - protected: - const size_t kStreamFlowControlWindowSize = 32 * 1024; // 32KB. - - QuicSimpleServerSessionServerPushTest() { - // Reset stream level flow control window to be 32KB. - if (GetParam().handshake_protocol == PROTOCOL_TLS1_3) { - if (VersionHasIetfQuicFrames(transport_version())) { - QuicConfigPeer::SetReceivedInitialMaxStreamDataBytesUnidirectional( - &config_, kStreamFlowControlWindowSize); - } else { - // In this version, push streams are server-initiated bidirectional - // streams, which are outgoing since we are the server here. - QuicConfigPeer:: - SetReceivedInitialMaxStreamDataBytesOutgoingBidirectional( - &config_, kStreamFlowControlWindowSize); - } - } else { - QuicConfigPeer::SetReceivedInitialStreamFlowControlWindow( - &config_, kStreamFlowControlWindowSize); - } - // Reset connection level flow control window to be 1.5 MB which is large - // enough that it won't block any stream to write before stream level flow - // control blocks it. - QuicConfigPeer::SetReceivedInitialSessionFlowControlWindow( - &config_, kInitialSessionFlowControlWindowForTest); - - ParsedQuicVersionVector supported_versions = SupportedVersions(GetParam()); - connection_ = new StrictMock( - &helper_, &alarm_factory_, Perspective::IS_SERVER, supported_versions); - connection_->SetEncrypter( - ENCRYPTION_FORWARD_SECURE, - std::make_unique(connection_->perspective())); - session_ = std::make_unique( - config_, connection_, &owner_, &stream_helper_, &crypto_config_, - &compressed_certs_cache_, &memory_cache_backend_); - session_->Initialize(); - // Needed to make new session flow control window and server push work. - - if (VersionHasIetfQuicFrames(transport_version())) { - EXPECT_CALL(*session_, WriteControlFrame(_, _)) - .WillRepeatedly(Invoke(&ClearControlFrameWithTransmissionType)); - } - session_->OnConfigNegotiated(); - - if (!VersionUsesHttp3(transport_version())) { - session_->UnregisterStreamPriority( - QuicUtils::GetHeadersStreamId(transport_version()), - /*is_static=*/true); - } - QuicSimpleServerSessionPeer::SetCryptoStream(session_.get(), nullptr); - // Assume encryption already established. - QuicCryptoServerStreamBase* crypto_stream = - CreateMockCryptoServerStream(&crypto_config_, &compressed_certs_cache_, - session_.get(), &stream_helper_); - - QuicSimpleServerSessionPeer::SetCryptoStream(session_.get(), crypto_stream); - if (!VersionUsesHttp3(transport_version())) { - session_->RegisterStreamPriority( - QuicUtils::GetHeadersStreamId(transport_version()), - /*is_static=*/true, - spdy::SpdyStreamPrecedence(QuicStream::kDefaultPriority)); - } - if (VersionUsesHttp3(transport_version())) { - // Ignore writes on the control stream. - auto send_control_stream = - QuicSpdySessionPeer::GetSendControlStream(session_.get()); - EXPECT_CALL(*connection_, - SendStreamData(send_control_stream->id(), _, _, NO_FIN)) - .Times(AnyNumber()); - } - } - - // Given |num_resources|, create this number of fake push resources and push - // them by sending PUSH_PROMISE for all and sending push responses for as much - // as possible(limited by kMaxStreamsForTest). - // If |num_resources| > kMaxStreamsForTest, the left over will be queued. - // Returns the length of the DATA frame header, or 0 if the version does not - // use DATA frames. - QuicByteCount PromisePushResources(size_t num_resources) { - // testing::InSequence seq; - // To prevent push streams from being closed the response need to be larger - // than stream flow control window so stream won't send the full body. - size_t body_size = 2 * kStreamFlowControlWindowSize; // 64KB. - - std::string request_url = "mail.google.com/"; - spdy::Http2HeaderBlock request_headers; - std::string resource_host = "www.google.com"; - std::string partial_push_resource_path = "/server_push_src"; - std::list push_resources; - std::string scheme = "http"; - QuicByteCount data_frame_header_length = 0; - for (unsigned int i = 1; i <= num_resources; ++i) { - QuicStreamId stream_id; - if (VersionUsesHttp3(transport_version())) { - stream_id = GetNthServerInitiatedUnidirectionalId(i + 2); - } else { - stream_id = GetNthServerInitiatedUnidirectionalId(i - 1); - } - std::string path = absl::StrCat(partial_push_resource_path, i); - std::string url = scheme + "://" + resource_host + path; - QuicUrl resource_url = QuicUrl(url); - std::string body(body_size, 'a'); - std::string data; - data_frame_header_length = 0; - if (VersionUsesHttp3(transport_version())) { - std::unique_ptr buffer; - data_frame_header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); - std::string header(buffer.get(), data_frame_header_length); - data = header + body; - } else { - data = body; - } - - memory_cache_backend_.AddSimpleResponse(resource_host, path, 200, data); - push_resources.push_back(QuicBackendResponse::ServerPushInfo( - resource_url, spdy::Http2HeaderBlock(), QuicStream::kDefaultPriority, - body)); - // PUSH_PROMISED are sent for all the resources. - EXPECT_CALL(*session_, - WritePushPromiseMock(GetNthClientInitiatedBidirectionalId(0), - stream_id, _)); - if (i <= kMaxStreamsForTest) { - // |kMaxStreamsForTest| promised responses should be sent. - // Since flow control window is smaller than response body, not the - // whole body will be sent. - QuicStreamOffset offset = 0; - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(stream_id, 1, offset, NO_FIN)); - offset++; - } - - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(stream_id, kHeadersFrameHeaderLength, - offset, NO_FIN)); - offset += kHeadersFrameHeaderLength; - EXPECT_CALL(*connection_, - SendStreamData(stream_id, kHeadersFramePayloadLength, - offset, NO_FIN)); - offset += kHeadersFramePayloadLength; - } - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(stream_id, data_frame_header_length, - offset, NO_FIN)); - offset += data_frame_header_length; - } - EXPECT_CALL(*connection_, SendStreamData(stream_id, _, offset, NO_FIN)) - .WillOnce(Return(QuicConsumedData( - kStreamFlowControlWindowSize - offset, false))); - EXPECT_CALL(*session_, SendBlocked(stream_id)); - } - } - session_->PromisePushResources( - request_url, push_resources, GetNthClientInitiatedBidirectionalId(0), - spdy::SpdyStreamPrecedence(0, spdy::kHttp2DefaultStreamWeight, false), - request_headers); - return data_frame_header_length; - } - - void MaybeConsumeHeadersStreamData() { - if (!VersionUsesHttp3(transport_version())) { - QuicStreamId headers_stream_id = - QuicUtils::GetHeadersStreamId(transport_version()); - EXPECT_CALL(*connection_, SendStreamData(headers_stream_id, _, _, _)) - .Times(AtLeast(1)); - } - } -}; - -ParsedQuicVersionVector SupportedVersionsWithPush() { - ParsedQuicVersionVector versions; - for (const ParsedQuicVersion& version : AllSupportedVersions()) { - if (!version.UsesHttp3()) { - // Push over HTTP/3 is not supported. - versions.push_back(version); - } - } - return versions; -} - -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSimpleServerSessionServerPushTest, - ::testing::ValuesIn(SupportedVersionsWithPush())); - -// Tests that given more than kMaxStreamsForTest resources, all their -// PUSH_PROMISE's will be sent out and only kMaxStreamsForTest streams will be -// opened and send push response. -TEST_P(QuicSimpleServerSessionServerPushTest, TestPromisePushResources) { - MaybeConsumeHeadersStreamData(); - size_t num_resources = kMaxStreamsForTest + 5; - PromisePushResources(num_resources); - EXPECT_EQ(kMaxStreamsForTest, - QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); -} - -// Tests that after promised stream queued up, when an opened stream is marked -// draining, a queued promised stream will become open and send push response. -TEST_P(QuicSimpleServerSessionServerPushTest, - HandlePromisedPushRequestsAfterStreamDraining) { - MaybeConsumeHeadersStreamData(); - size_t num_resources = kMaxStreamsForTest + 1; - QuicByteCount data_frame_header_length = PromisePushResources(num_resources); - QuicStreamId next_out_going_stream_id; - if (VersionUsesHttp3(transport_version())) { - next_out_going_stream_id = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest + 3); - } else { - next_out_going_stream_id = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest); - } - - // After an open stream is marked draining, a new stream is expected to be - // created and a response sent on the stream. - QuicStreamOffset offset = 0; - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(next_out_going_stream_id, 1, offset, NO_FIN)); - offset++; - } - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(next_out_going_stream_id, - kHeadersFrameHeaderLength, offset, NO_FIN)); - offset += kHeadersFrameHeaderLength; - EXPECT_CALL(*connection_, - SendStreamData(next_out_going_stream_id, - kHeadersFramePayloadLength, offset, NO_FIN)); - offset += kHeadersFramePayloadLength; - } - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(next_out_going_stream_id, - data_frame_header_length, offset, NO_FIN)); - offset += data_frame_header_length; - } - EXPECT_CALL(*connection_, - SendStreamData(next_out_going_stream_id, _, offset, NO_FIN)) - .WillOnce(Return( - QuicConsumedData(kStreamFlowControlWindowSize - offset, false))); - EXPECT_CALL(*session_, SendBlocked(next_out_going_stream_id)); - - if (VersionHasIetfQuicFrames(transport_version())) { - // The PromisePushedResources call, above, will have used all available - // stream ids. For version 99, stream ids are not made available until - // a MAX_STREAMS frame is received. This emulates the reception of one. - // For pre-v-99, the node monitors its own stream usage and makes streams - // available as it closes/etc them. - // Version 99 also has unidirectional static streams, so we need to send - // MaxStreamFrame of the number of resources + number of static streams. - session_->OnMaxStreamsFrame( - QuicMaxStreamsFrame(0, num_resources + 3, /*unidirectional=*/true)); - } - - if (VersionUsesHttp3(transport_version())) { - session_->StreamDraining(GetNthServerInitiatedUnidirectionalId(3), - /*unidirectional=*/true); - } else { - session_->StreamDraining(GetNthServerInitiatedUnidirectionalId(0), - /*unidirectional=*/true); - } - // Number of open outgoing streams should still be the same, because a new - // stream is opened. And the queue should be empty. - EXPECT_EQ(kMaxStreamsForTest, - QuicSessionPeer::GetNumOpenDynamicStreams(session_.get())); -} - -// Tests that after all resources are promised, a RST frame from client can -// prevent a promised resource to be send out. -TEST_P(QuicSimpleServerSessionServerPushTest, - ResetPromisedStreamToCancelServerPush) { - if (VersionHasIetfQuicFrames(transport_version())) { - // This test is resetting a stream that is not opened yet. IETF QUIC has no - // way to handle this. Some similar tests can be added once CANCEL_PUSH is - // supported. - return; - } - MaybeConsumeHeadersStreamData(); - - // Having two extra resources to be send later. One of them will be reset, so - // when opened stream become close, only one will become open. - size_t num_resources = kMaxStreamsForTest + 2; - if (VersionHasIetfQuicFrames(transport_version())) { - // V99 will send out a STREAMS_BLOCKED frame when it tries to exceed the - // limit. This will clear the frames so that they do not block the later - // rst-stream frame. - EXPECT_CALL(*session_, WriteControlFrame(_, _)) - .WillOnce(Invoke(&ClearControlFrameWithTransmissionType)); - } - QuicByteCount data_frame_header_length = PromisePushResources(num_resources); - - // Reset the last stream in the queue. It should be marked cancelled. - QuicStreamId stream_got_reset; - if (VersionUsesHttp3(transport_version())) { - stream_got_reset = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest + 4); - } else { - stream_got_reset = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest + 1); - } - QuicRstStreamFrame rst(kInvalidControlFrameId, stream_got_reset, - QUIC_STREAM_CANCELLED, 0); - EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); - EXPECT_CALL(*session_, WriteControlFrame(_, _)) - .WillOnce(Invoke(&ClearControlFrameWithTransmissionType)); - EXPECT_CALL(*connection_, - OnStreamReset(stream_got_reset, QUIC_RST_ACKNOWLEDGEMENT)); - session_->OnRstStream(rst); - - // When the first 2 streams becomes draining, the two queued up stream could - // be created. But since one of them was marked cancelled due to RST frame, - // only one queued resource will be sent out. - QuicStreamId stream_not_reset; - if (VersionUsesHttp3(transport_version())) { - stream_not_reset = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest + 3); - } else { - stream_not_reset = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest); - } - InSequence s; - QuicStreamOffset offset = 0; - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(stream_not_reset, 1, offset, NO_FIN)); - offset++; - EXPECT_CALL(*connection_, - SendStreamData(stream_not_reset, kHeadersFrameHeaderLength, - offset, NO_FIN)); - offset += kHeadersFrameHeaderLength; - EXPECT_CALL(*connection_, - SendStreamData(stream_not_reset, kHeadersFramePayloadLength, - offset, NO_FIN)); - offset += kHeadersFramePayloadLength; - EXPECT_CALL(*connection_, - SendStreamData(stream_not_reset, data_frame_header_length, - offset, NO_FIN)); - offset += data_frame_header_length; - } - EXPECT_CALL(*connection_, SendStreamData(stream_not_reset, _, offset, NO_FIN)) - .WillOnce(Return( - QuicConsumedData(kStreamFlowControlWindowSize - offset, false))); - EXPECT_CALL(*session_, SendBlocked(stream_not_reset)); - - if (VersionHasIetfQuicFrames(transport_version())) { - // The PromisePushedResources call, above, will have used all available - // stream ids. For version 99, stream ids are not made available until - // a MAX_STREAMS frame is received. This emulates the reception of one. - // For pre-v-99, the node monitors its own stream usage and makes streams - // available as it closes/etc them. - session_->OnMaxStreamsFrame( - QuicMaxStreamsFrame(0, num_resources + 3, /*unidirectional=*/true)); - } - session_->StreamDraining(GetNthServerInitiatedUnidirectionalId(3), - /*unidirectional=*/true); - session_->StreamDraining(GetNthServerInitiatedUnidirectionalId(4), - /*unidirectional=*/true); -} - -// Tests that closing a open outgoing stream can trigger a promised resource in -// the queue to be send out. -TEST_P(QuicSimpleServerSessionServerPushTest, - CloseStreamToHandleMorePromisedStream) { - MaybeConsumeHeadersStreamData(); - size_t num_resources = kMaxStreamsForTest + 1; - if (VersionHasIetfQuicFrames(transport_version())) { - // V99 will send out a stream-id-blocked frame when the we desired to exceed - // the limit. This will clear the frames so that they do not block the later - // rst-stream frame. - EXPECT_CALL(*session_, WriteControlFrame(_, _)) - .WillOnce(Invoke(&ClearControlFrameWithTransmissionType)); - } - QuicByteCount data_frame_header_length = PromisePushResources(num_resources); - QuicStreamId stream_to_open; - if (VersionUsesHttp3(transport_version())) { - stream_to_open = - GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest + 3); - } else { - stream_to_open = GetNthServerInitiatedUnidirectionalId(kMaxStreamsForTest); - } - - // Resetting an open stream will close the stream and give space for extra - // stream to be opened. - QuicStreamId stream_got_reset = GetNthServerInitiatedUnidirectionalId(3); - EXPECT_CALL(*session_, WriteControlFrame(_, _)); - if (!VersionHasIetfQuicFrames(transport_version())) { - EXPECT_CALL(owner_, OnRstStreamReceived(_)).Times(1); - // For version 99, this is covered in InjectStopSending() - EXPECT_CALL(*connection_, - OnStreamReset(stream_got_reset, QUIC_RST_ACKNOWLEDGEMENT)); - } - QuicStreamOffset offset = 0; - if (VersionUsesHttp3(transport_version())) { - EXPECT_CALL(*connection_, - SendStreamData(stream_to_open, 1, offset, NO_FIN)); - offset++; - EXPECT_CALL(*connection_, - SendStreamData(stream_to_open, kHeadersFrameHeaderLength, - offset, NO_FIN)); - offset += kHeadersFrameHeaderLength; - EXPECT_CALL(*connection_, - SendStreamData(stream_to_open, kHeadersFramePayloadLength, - offset, NO_FIN)); - offset += kHeadersFramePayloadLength; - EXPECT_CALL(*connection_, - SendStreamData(stream_to_open, data_frame_header_length, offset, - NO_FIN)); - offset += data_frame_header_length; - } - EXPECT_CALL(*connection_, SendStreamData(stream_to_open, _, offset, NO_FIN)) - .WillOnce(Return( - QuicConsumedData(kStreamFlowControlWindowSize - offset, false))); - - EXPECT_CALL(*session_, SendBlocked(stream_to_open)); - QuicRstStreamFrame rst(kInvalidControlFrameId, stream_got_reset, - QUIC_STREAM_CANCELLED, 0); - if (VersionHasIetfQuicFrames(transport_version())) { - // The PromisePushedResources call, above, will have used all available - // stream ids. For version 99, stream ids are not made available until - // a MAX_STREAMS frame is received. This emulates the reception of one. - // For pre-v-99, the node monitors its own stream usage and makes streams - // available as it closes/etc them. - session_->OnMaxStreamsFrame( - QuicMaxStreamsFrame(0, num_resources + 3, /*unidirectional=*/true)); - } else { - session_->OnRstStream(rst); - } - // Create and inject a STOP_SENDING frame. In GOOGLE QUIC, receiving a - // RST_STREAM frame causes a two-way close. For IETF QUIC, RST_STREAM causes - // a one-way close. - InjectStopSending(stream_got_reset, QUIC_STREAM_CANCELLED); -} - } // namespace } // namespace test } // namespace quic diff --git a/gquiche/quic/tools/quic_simple_server_stream.cc b/gquiche/quic/tools/quic_simple_server_stream.cc index e0de378d..c83fac9f 100644 --- a/gquiche/quic/tools/quic_simple_server_stream.cc +++ b/gquiche/quic/tools/quic_simple_server_stream.cc @@ -18,7 +18,6 @@ #include "gquiche/quic/platform/api/quic_bug_tracker.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/quic/platform/api/quic_map_util.h" #include "gquiche/quic/tools/quic_simple_server_session.h" #include "gquiche/spdy/core/spdy_protocol.h" @@ -39,11 +38,9 @@ QuicSimpleServerStream::QuicSimpleServerStream( } QuicSimpleServerStream::QuicSimpleServerStream( - PendingStream* pending, - QuicSpdySession* session, - StreamType type, + PendingStream* pending, QuicSpdySession* session, QuicSimpleServerBackend* quic_simple_server_backend) - : QuicSpdyServerStreamBase(pending, session, type), + : QuicSpdyServerStreamBase(pending, session), content_length_(-1), generate_bytes_length_(0), quic_simple_server_backend_(quic_simple_server_backend) { @@ -59,10 +56,13 @@ void QuicSimpleServerStream::OnInitialHeadersComplete( size_t frame_len, const QuicHeaderList& header_list) { QuicSpdyStream::OnInitialHeadersComplete(fin, frame_len, header_list); - if (!SpdyUtils::CopyAndValidateHeaders(header_list, &content_length_, + // QuicSpdyStream::OnInitialHeadersComplete() may have already sent error + // response. + if (!response_sent_ && + !SpdyUtils::CopyAndValidateHeaders(header_list, &content_length_, &request_headers_)) { QUIC_DVLOG(1) << "Invalid headers"; - SendErrorResponse(); + SendErrorResponse(); } ConsumeHeaderList(); if (!fin && !response_sent_) { @@ -155,13 +155,13 @@ void QuicSimpleServerStream::SendResponse() { return; } - if (!QuicContainsKey(request_headers_, ":authority")) { + if (!request_headers_.contains(":authority")) { QUIC_DVLOG(1) << "Request headers do not contain :authority."; SendErrorResponse(); return; } - if (!QuicContainsKey(request_headers_, ":path")) { + if (!request_headers_.contains(":path")) { // CONNECT and other CONNECT-like methods (such as CONNECT-UDP) do not all // require :path to be present. auto it = request_headers_.find(":method"); @@ -214,8 +214,7 @@ std::string QuicSimpleServerStream::peer_host() const { } void QuicSimpleServerStream::OnResponseBackendComplete( - const QuicBackendResponse* response, - std::list resources) { + const QuicBackendResponse* response) { if (response == nullptr) { QUIC_DVLOG(1) << "Response not found in cache."; SendNotFoundResponse(); @@ -285,15 +284,6 @@ void QuicSimpleServerStream::OnResponseBackendComplete( } } - if (!resources.empty()) { - QUIC_DVLOG(1) << "Stream " << id() << " found " << resources.size() - << " push resources."; - QuicSimpleServerSession* session = - static_cast(spdy_session()); - session->PromisePushResources(request_url, resources, id(), precedence(), - request_headers_); - } - if (response->response_type() == QuicBackendResponse::INCOMPLETE_RESPONSE) { QUIC_DVLOG(1) << "Stream " << id() @@ -357,6 +347,9 @@ void QuicSimpleServerStream::SendErrorResponse() { void QuicSimpleServerStream::SendErrorResponse(int resp_code) { QUIC_DVLOG(1) << "Stream " << id() << " sending error response."; + if (!reading_stopped()) { + StopReading(); + } Http2HeaderBlock headers; if (resp_code <= 0) { headers[":status"] = "500"; @@ -424,6 +417,11 @@ void QuicSimpleServerStream::SendHeadersAndBodyAndTrailers( WriteTrailers(std::move(response_trailers), nullptr); } +void QuicSimpleServerStream::OnInvalidHeaders() { + QUIC_DVLOG(1) << "Invalid headers"; + SendErrorResponse(400); +} + const char* const QuicSimpleServerStream::kErrorResponseBody = "bad"; const char* const QuicSimpleServerStream::kNotFoundResponseBody = "file not found"; diff --git a/gquiche/quic/tools/quic_simple_server_stream.h b/gquiche/quic/tools/quic_simple_server_stream.h index fa4d48bb..95295ff0 100644 --- a/gquiche/quic/tools/quic_simple_server_stream.h +++ b/gquiche/quic/tools/quic_simple_server_stream.h @@ -25,7 +25,6 @@ class QuicSimpleServerStream : public QuicSpdyServerStreamBase, QuicSimpleServerBackend* quic_simple_server_backend); QuicSimpleServerStream(PendingStream* pending, QuicSpdySession* session, - StreamType type, QuicSimpleServerBackend* quic_simple_server_backend); QuicSimpleServerStream(const QuicSimpleServerStream&) = delete; QuicSimpleServerStream& operator=(const QuicSimpleServerStream&) = delete; @@ -44,6 +43,8 @@ class QuicSimpleServerStream : public QuicSpdyServerStreamBase, // data (or a FIN) to be read. void OnBodyAvailable() override; + void OnInvalidHeaders() override; + // Make this stream start from as if it just finished parsing an incoming // request whose headers are equivalent to |push_request_headers|. // Doing so will trigger this toy stream to fetch response and send it back. @@ -57,9 +58,7 @@ class QuicSimpleServerStream : public QuicSpdyServerStreamBase, QuicConnectionId connection_id() const override; QuicStreamId stream_id() const override; std::string peer_host() const override; - void OnResponseBackendComplete( - const QuicBackendResponse* response, - std::list resources) override; + void OnResponseBackendComplete(const QuicBackendResponse* response) override; protected: // Sends a basic 200 response using SendHeaders for the headers and WriteData @@ -69,7 +68,7 @@ class QuicSimpleServerStream : public QuicSpdyServerStreamBase, // Sends a basic 500 response using SendHeaders for the headers and WriteData // for the body. virtual void SendErrorResponse(); - void SendErrorResponse(int resp_code); + virtual void SendErrorResponse(int resp_code); // Sends a basic 404 response using SendHeaders for the headers and WriteData // for the body. diff --git a/gquiche/quic/tools/quic_simple_server_stream_test.cc b/gquiche/quic/tools/quic_simple_server_stream_test.cc index 77163d83..602077ff 100644 --- a/gquiche/quic/tools/quic_simple_server_stream_test.cc +++ b/gquiche/quic/tools/quic_simple_server_stream_test.cc @@ -16,6 +16,7 @@ #include "gquiche/quic/core/http/http_encoder.h" #include "gquiche/quic/core/http/spdy_utils.h" #include "gquiche/quic/core/quic_error_codes.h" +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/core/quic_types.h" #include "gquiche/quic/core/quic_utils.h" #include "gquiche/quic/platform/api/quic_expect_bug.h" @@ -47,13 +48,9 @@ const size_t kDataFrameHeaderLength = 2; class TestStream : public QuicSimpleServerStream { public: - TestStream(QuicStreamId stream_id, - QuicSpdySession* session, - StreamType type, + TestStream(QuicStreamId stream_id, QuicSpdySession* session, StreamType type, QuicSimpleServerBackend* quic_simple_server_backend) - : QuicSimpleServerStream(stream_id, - session, - type, + : QuicSimpleServerStream(stream_id, session, type, quic_simple_server_backend) {} ~TestStream() override = default; @@ -61,8 +58,7 @@ class TestStream : public QuicSimpleServerStream { MOCK_METHOD(void, WriteHeadersMock, (bool fin), ()); MOCK_METHOD(void, WriteEarlyHintsHeadersMock, (bool fin), ()); - size_t WriteHeaders(spdy::Http2HeaderBlock header_block, - bool fin, + size_t WriteHeaders(spdy::Http2HeaderBlock header_block, bool fin, QuicReferenceCountedPointer /*ack_listener*/) override { if (header_block[":status"] == "103") { @@ -75,7 +71,7 @@ class TestStream : public QuicSimpleServerStream { // Expose protected QuicSimpleServerStream methods. void DoSendResponse() { SendResponse(); } - void DoSendErrorResponse() { SendErrorResponse(); } + void DoSendErrorResponse() { QuicSimpleServerStream::SendErrorResponse(); } spdy::Http2HeaderBlock* mutable_headers() { return &request_headers_; } void set_body(std::string body) { body_ = std::move(body); } @@ -98,9 +94,9 @@ class TestStream : public QuicSimpleServerStream { QuicSimpleServerStream::SendResponse(); } - void SendErrorResponse() override { + void SendErrorResponse(int resp_code) override { send_error_response_was_called_ = true; - QuicSimpleServerStream::SendErrorResponse(); + QuicSimpleServerStream::SendErrorResponse(resp_code); } private: @@ -115,18 +111,13 @@ class MockQuicSimpleServerSession : public QuicSimpleServerSession { const size_t kMaxStreamsForTest = 100; MockQuicSimpleServerSession( - QuicConnection* connection, - MockQuicSessionVisitor* owner, + QuicConnection* connection, MockQuicSessionVisitor* owner, MockQuicCryptoServerStreamHelper* helper, QuicCryptoServerConfig* crypto_config, QuicCompressedCertsCache* compressed_certs_cache, QuicSimpleServerBackend* quic_simple_server_backend) - : QuicSimpleServerSession(DefaultQuicConfig(), - CurrentSupportedVersions(), - connection, - owner, - helper, - crypto_config, + : QuicSimpleServerSession(DefaultQuicConfig(), CurrentSupportedVersions(), + connection, owner, helper, crypto_config, compressed_certs_cache, quic_simple_server_backend) { if (VersionHasIetfQuicFrames(connection->transport_version())) { @@ -147,70 +138,35 @@ class MockQuicSimpleServerSession : public QuicSimpleServerSession { delete; ~MockQuicSimpleServerSession() override = default; - MOCK_METHOD(void, - OnConnectionClosed, + MOCK_METHOD(void, OnConnectionClosed, (const QuicConnectionCloseFrame& frame, ConnectionCloseSource source), (override)); - MOCK_METHOD(QuicSpdyStream*, - CreateIncomingStream, - (QuicStreamId id), + MOCK_METHOD(QuicSpdyStream*, CreateIncomingStream, (QuicStreamId id), (override)); - MOCK_METHOD(QuicConsumedData, - WritevData, - (QuicStreamId id, - size_t write_length, - QuicStreamOffset offset, - StreamSendingState state, - TransmissionType type, - absl::optional level), + MOCK_METHOD(QuicConsumedData, WritevData, + (QuicStreamId id, size_t write_length, QuicStreamOffset offset, + StreamSendingState state, TransmissionType type, + EncryptionLevel level), (override)); - MOCK_METHOD(void, - OnStreamHeaderList, - (QuicStreamId stream_id, - bool fin, - size_t frame_len, + MOCK_METHOD(void, OnStreamHeaderList, + (QuicStreamId stream_id, bool fin, size_t frame_len, const QuicHeaderList& header_list), (override)); - MOCK_METHOD(void, - OnStreamHeadersPriority, + MOCK_METHOD(void, OnStreamHeadersPriority, (QuicStreamId stream_id, const spdy::SpdyStreamPrecedence& precedence), (override)); - MOCK_METHOD(void, - MaybeSendRstStreamFrame, - (QuicStreamId stream_id, - QuicRstStreamErrorCode error, + MOCK_METHOD(void, MaybeSendRstStreamFrame, + (QuicStreamId stream_id, QuicResetStreamError error, QuicStreamOffset bytes_written), (override)); - MOCK_METHOD(void, - MaybeSendStopSendingFrame, - (QuicStreamId stream_id, QuicRstStreamErrorCode error), - (override)); - // Matchers cannot be used on non-copyable types like Http2HeaderBlock. - void PromisePushResources( - const std::string& request_url, - const std::list& resources, - QuicStreamId original_stream_id, - const spdy::SpdyStreamPrecedence& original_precedence, - const spdy::Http2HeaderBlock& original_request_headers) override { - original_request_headers_ = original_request_headers.Clone(); - PromisePushResourcesMock(request_url, resources, original_stream_id, - original_precedence, original_request_headers); - } - MOCK_METHOD(void, - PromisePushResourcesMock, - (const std::string&, - const std::list&, - QuicStreamId, - const spdy::SpdyStreamPrecedence&, - const spdy::Http2HeaderBlock&), - ()); + MOCK_METHOD(void, MaybeSendStopSendingFrame, + (QuicStreamId stream_id, QuicResetStreamError error), (override)); using QuicSession::ActivateStream; - QuicConsumedData ConsumeData(QuicStreamId id, - size_t write_length, + QuicConsumedData ConsumeData(QuicStreamId id, size_t write_length, QuicStreamOffset offset, StreamSendingState state, TransmissionType /*type*/, @@ -233,23 +189,17 @@ class MockQuicSimpleServerSession : public QuicSimpleServerSession { class QuicSimpleServerStreamTest : public QuicTestWithParam { public: QuicSimpleServerStreamTest() - : connection_( - new StrictMock(&helper_, - &alarm_factory_, - Perspective::IS_SERVER, - SupportedVersions(GetParam()))), + : connection_(new StrictMock( + &helper_, &alarm_factory_, Perspective::IS_SERVER, + SupportedVersions(GetParam()))), crypto_config_(new QuicCryptoServerConfig( - QuicCryptoServerConfig::TESTING, - QuicRandom::GetInstance(), + QuicCryptoServerConfig::TESTING, QuicRandom::GetInstance(), crypto_test_utils::ProofSourceForTesting(), KeyExchangeSource::Default())), compressed_certs_cache_( QuicCompressedCertsCache::kQuicCompressedCertsCacheSize), - session_(connection_, - &session_owner_, - &session_helper_, - crypto_config_.get(), - &compressed_certs_cache_, + session_(connection_, &session_owner_, &session_helper_, + crypto_config_.get(), &compressed_certs_cache_, &memory_cache_backend_), quic_response_(new QuicBackendResponse), body_("hello world") { @@ -258,6 +208,7 @@ class QuicSimpleServerStreamTest : public QuicTestWithParam { header_list_.OnHeader(":authority", "www.google.com"); header_list_.OnHeader(":path", "/"); header_list_.OnHeader(":method", "POST"); + header_list_.OnHeader(":scheme", "https"); header_list_.OnHeader("content-length", "11"); header_list_.OnHeaderBlockEnd(128, 128); @@ -320,8 +271,7 @@ class QuicSimpleServerStreamTest : public QuicTestWithParam { QuicHeaderList header_list_; }; -INSTANTIATE_TEST_SUITE_P(Tests, - QuicSimpleServerStreamTest, +INSTANTIATE_TEST_SUITE_P(Tests, QuicSimpleServerStreamTest, ::testing::ValuesIn(AllSupportedVersions()), ::testing::PrintToStringParamName()); @@ -330,11 +280,10 @@ TEST_P(QuicSimpleServerStreamTest, TestFraming) { .WillRepeatedly( Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = UsesHttp3() ? header + body_ : body_; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); EXPECT_EQ("11", StreamHeadersValue("content-length")); @@ -349,11 +298,10 @@ TEST_P(QuicSimpleServerStreamTest, TestFramingOnePacket) { Invoke(&session_, &MockQuicSimpleServerSession::ConsumeData)); stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = UsesHttp3() ? header + body_ : body_; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); EXPECT_EQ("11", StreamHeadersValue("content-length")); @@ -374,10 +322,15 @@ TEST_P(QuicSimpleServerStreamTest, SendQuicRstStreamNoErrorInStopReading) { stream_->CloseWriteSide(); if (session_.version().UsesHttp3()) { - EXPECT_CALL(session_, MaybeSendStopSendingFrame(_, QUIC_STREAM_NO_ERROR)) + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))) .Times(1); } else { - EXPECT_CALL(session_, MaybeSendRstStreamFrame(_, QUIC_STREAM_NO_ERROR, _)) + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_NO_ERROR), _)) .Times(1); } stream_->StopReading(); @@ -396,20 +349,20 @@ TEST_P(QuicSimpleServerStreamTest, TestFramingExtraData) { EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); stream_->OnStreamHeaderList(false, kFakeFrameLen, header_list_); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = UsesHttp3() ? header + body_ : body_; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); // Content length is still 11. This will register as an error and we won't // accept the bytes. - header_length = - HttpEncoder::SerializeDataFrameHeader(large_body.length(), &buffer); - header = std::string(buffer.get(), header_length); - std::string data2 = UsesHttp3() ? header + large_body : large_body; + header = HttpEncoder::SerializeDataFrameHeader(large_body.length(), + SimpleBufferAllocator::Get()); + std::string data2 = UsesHttp3() + ? absl::StrCat(header.AsStringView(), large_body) + : large_body; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/true, data.size(), data2)); EXPECT_EQ("11", StreamHeadersValue("content-length")); @@ -428,9 +381,8 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithIllegalResponseStatus) { response_headers_[":status"] = "200 OK"; response_headers_["content-length"] = "5"; std::string body = "Yummm"; - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); memory_cache_backend_.AddResponse("www.google.com", "/bar", std::move(response_headers_), body); @@ -440,7 +392,7 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithIllegalResponseStatus) { InSequence s; EXPECT_CALL(*stream_, WriteHeadersMock(false)); if (UsesHttp3()) { - EXPECT_CALL(session_, WritevData(_, header_length, _, NO_FIN, _, _)); + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); } EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); @@ -461,9 +413,8 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithIllegalResponseStatus2) { response_headers_["content-length"] = "5"; std::string body = "Yummm"; - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); memory_cache_backend_.AddResponse("www.google.com", "/bar", std::move(response_headers_), body); @@ -473,7 +424,7 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithIllegalResponseStatus2) { InSequence s; EXPECT_CALL(*stream_, WriteHeadersMock(false)); if (UsesHttp3()) { - EXPECT_CALL(session_, WritevData(_, header_length, _, NO_FIN, _, _)); + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); } EXPECT_CALL(session_, WritevData(_, kErrorLength, _, FIN, _, _)); @@ -506,11 +457,16 @@ TEST_P(QuicSimpleServerStreamTest, SendPushResponseWith404Response) { InSequence s; if (session_.version().UsesHttp3()) { - EXPECT_CALL(session_, MaybeSendStopSendingFrame(promised_stream->id(), - QUIC_STREAM_CANCELLED)); + EXPECT_CALL(session_, + MaybeSendStopSendingFrame( + promised_stream->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED))); } - EXPECT_CALL(session_, MaybeSendRstStreamFrame(promised_stream->id(), - QUIC_STREAM_CANCELLED, 0)); + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + promised_stream->id(), + QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED), 0)); promised_stream->DoSendResponse(); } @@ -526,9 +482,8 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithValidHeaders) { response_headers_["content-length"] = "5"; std::string body = "Yummm"; - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); memory_cache_backend_.AddResponse("www.google.com", "/bar", std::move(response_headers_), body); @@ -537,7 +492,7 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithValidHeaders) { InSequence s; EXPECT_CALL(*stream_, WriteHeadersMock(false)); if (UsesHttp3()) { - EXPECT_CALL(session_, WritevData(_, header_length, _, NO_FIN, _, _)); + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); } EXPECT_CALL(session_, WritevData(_, body.length(), _, FIN, _, _)); @@ -546,46 +501,6 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithValidHeaders) { EXPECT_TRUE(stream_->write_side_closed()); } -TEST_P(QuicSimpleServerStreamTest, SendResponseWithPushResources) { - // Tests that if a response has push resources to be send, SendResponse() will - // call PromisePushResources() to handle these resources. - - // Add a request and response with valid headers into cache. - std::string host = "www.google.com"; - std::string request_path = "/foo"; - std::string body = "Yummm"; - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); - QuicBackendResponse::ServerPushInfo push_info( - QuicUrl(host, "/bar"), spdy::Http2HeaderBlock(), - QuicStream::kDefaultPriority, "Push body"); - std::list push_resources; - push_resources.push_back(push_info); - memory_cache_backend_.AddSimpleResponseWithServerPushResources( - host, request_path, 200, body, push_resources); - - spdy::Http2HeaderBlock* request_headers = stream_->mutable_headers(); - (*request_headers)[":path"] = request_path; - (*request_headers)[":authority"] = host; - (*request_headers)[":method"] = "GET"; - - QuicStreamPeer::SetFinReceived(stream_); - InSequence s; - EXPECT_CALL(session_, PromisePushResourcesMock( - host + request_path, _, - GetNthClientInitiatedBidirectionalStreamId( - connection_->transport_version(), 0), - _, _)); - EXPECT_CALL(*stream_, WriteHeadersMock(false)); - if (UsesHttp3()) { - EXPECT_CALL(session_, WritevData(_, header_length, _, NO_FIN, _, _)); - } - EXPECT_CALL(session_, WritevData(_, body.length(), _, FIN, _, _)); - stream_->DoSendResponse(); - EXPECT_EQ(*request_headers, session_.original_request_headers_); -} - TEST_P(QuicSimpleServerStreamTest, SendResponseWithEarlyHints) { std::string host = "www.google.com"; std::string request_path = "/foo"; @@ -597,9 +512,8 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithEarlyHints) { (*request_headers)[":authority"] = host; (*request_headers)[":method"] = "GET"; - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body.length(), &buffer); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body.length(), SimpleBufferAllocator::Get()); std::vector early_hints; // Add two Early Hints. const size_t kNumEarlyHintsResponses = 2; @@ -621,7 +535,7 @@ TEST_P(QuicSimpleServerStreamTest, SendResponseWithEarlyHints) { } EXPECT_CALL(*stream_, WriteHeadersMock(false)); if (UsesHttp3()) { - EXPECT_CALL(session_, WritevData(_, header_length, _, NO_FIN, _, _)); + EXPECT_CALL(session_, WritevData(_, header.size(), _, NO_FIN, _, _)); } EXPECT_CALL(session_, WritevData(_, body.length(), _, FIN, _, _)); @@ -667,9 +581,8 @@ TEST_P(QuicSimpleServerStreamTest, PushResponseOnServerInitiatedStream) { response_headers_[":status"] = "200"; response_headers_["content-length"] = "5"; const std::string kBody = "Hello"; - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(kBody.length(), &buffer); + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); memory_cache_backend_.AddResponse(kHost, kPath, std::move(response_headers_), kBody); @@ -679,7 +592,7 @@ TEST_P(QuicSimpleServerStreamTest, PushResponseOnServerInitiatedStream) { EXPECT_CALL(*server_initiated_stream, WriteHeadersMock(false)); if (UsesHttp3()) { - EXPECT_CALL(session_, WritevData(kServerInitiatedStreamId, header_length, _, + EXPECT_CALL(session_, WritevData(kServerInitiatedStreamId, header.size(), _, NO_FIN, _, _)); } EXPECT_CALL(session_, @@ -764,11 +677,14 @@ TEST_P(QuicSimpleServerStreamTest, .Times(AnyNumber()); } - EXPECT_CALL(session_, MaybeSendRstStreamFrame(_, - session_.version().UsesHttp3() - ? QUIC_STREAM_CANCELLED - : QUIC_RST_ACKNOWLEDGEMENT, - _)) + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, + session_.version().UsesHttp3() + ? QuicResetStreamError::FromInternal(QUIC_STREAM_CANCELLED) + : QuicResetStreamError::FromInternal(QUIC_RST_ACKNOWLEDGEMENT), + _)) .Times(1); QuicRstStreamFrame rst_frame(kInvalidControlFrameId, stream_->id(), QUIC_STREAM_CANCELLED, 1234); @@ -787,27 +703,27 @@ TEST_P(QuicSimpleServerStreamTest, TEST_P(QuicSimpleServerStreamTest, InvalidHeadersWithFin) { char arr[] = { - 0x3a, 0x68, 0x6f, 0x73, // :hos - 0x74, 0x00, 0x00, 0x00, // t... - 0x00, 0x00, 0x00, 0x00, // .... - 0x07, 0x3a, 0x6d, 0x65, // .:me - 0x74, 0x68, 0x6f, 0x64, // thod - 0x00, 0x00, 0x00, 0x03, // .... - 0x47, 0x45, 0x54, 0x00, // GET. - 0x00, 0x00, 0x05, 0x3a, // ...: - 0x70, 0x61, 0x74, 0x68, // path - 0x00, 0x00, 0x00, 0x04, // .... - 0x2f, 0x66, 0x6f, 0x6f, // /foo - 0x00, 0x00, 0x00, 0x07, // .... - 0x3a, 0x73, 0x63, 0x68, // :sch - 0x65, 0x6d, 0x65, 0x00, // eme. - 0x00, 0x00, 0x00, 0x00, // .... - 0x00, 0x00, 0x08, 0x3a, // ...: - 0x76, 0x65, 0x72, 0x73, // vers - 0x96, 0x6f, 0x6e, 0x00, // on. - 0x00, 0x00, 0x08, 0x48, // ...H - 0x54, 0x54, 0x50, 0x2f, // TTP/ - 0x31, 0x2e, 0x31, // 1.1 + 0x3a, 0x68, 0x6f, 0x73, // :hos + 0x74, 0x00, 0x00, 0x00, // t... + 0x00, 0x00, 0x00, 0x00, // .... + 0x07, 0x3a, 0x6d, 0x65, // .:me + 0x74, 0x68, 0x6f, 0x64, // thod + 0x00, 0x00, 0x00, 0x03, // .... + 0x47, 0x45, 0x54, 0x00, // GET. + 0x00, 0x00, 0x05, 0x3a, // ...: + 0x70, 0x61, 0x74, 0x68, // path + 0x00, 0x00, 0x00, 0x04, // .... + 0x2f, 0x66, 0x6f, 0x6f, // /foo + 0x00, 0x00, 0x00, 0x07, // .... + 0x3a, 0x73, 0x63, 0x68, // :sch + 0x65, 0x6d, 0x65, 0x00, // eme. + 0x00, 0x00, 0x00, 0x00, // .... + 0x00, 0x00, 0x08, 0x3a, // ...: + 0x76, 0x65, 0x72, 0x73, // vers + '\x96', 0x6f, 0x6e, 0x00, // on. + 0x00, 0x00, 0x08, 0x48, // ...H + 0x54, 0x54, 0x50, 0x2f, // TTP/ + 0x31, 0x2e, 0x31, // 1.1 }; absl::string_view data(arr, ABSL_ARRAYSIZE(arr)); QuicStreamFrame frame(stream_->id(), true, 0, data); @@ -822,18 +738,17 @@ TEST_P(QuicSimpleServerStreamTest, ConnectSendsResponseBeforeFinReceived) { QuicHeaderList header_list; header_list.OnHeaderBlockStart(); header_list.OnHeader(":authority", "www.google.com:4433"); - header_list.OnHeader(":method", "CONNECT-SILLY"); + header_list.OnHeader(":method", "CONNECT"); header_list.OnHeaderBlockEnd(128, 128); EXPECT_CALL(*stream_, WriteHeadersMock(/*fin=*/false)); stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, header_list); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = UsesHttp3() ? header + body_ : body_; + QuicBuffer header = HttpEncoder::SerializeDataFrameHeader( + body_.length(), SimpleBufferAllocator::Get()); + std::string data = + UsesHttp3() ? absl::StrCat(header.AsStringView(), body_) : body_; stream_->OnStreamFrame( QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); - EXPECT_EQ("CONNECT-SILLY", StreamHeadersValue(":method")); + EXPECT_EQ("CONNECT", StreamHeadersValue(":method")); EXPECT_EQ(body_, StreamBody()); EXPECT_TRUE(stream_->send_response_was_called()); EXPECT_FALSE(stream_->send_error_response_was_called()); @@ -846,21 +761,25 @@ TEST_P(QuicSimpleServerStreamTest, ConnectWithInvalidHeader) { QuicHeaderList header_list; header_list.OnHeaderBlockStart(); header_list.OnHeader(":authority", "www.google.com:4433"); - header_list.OnHeader(":method", "CONNECT-SILLY"); + header_list.OnHeader(":method", "CONNECT"); // QUIC requires lower-case header names. header_list.OnHeader("InVaLiD-HeAdEr", "Well that's just wrong!"); header_list.OnHeaderBlockEnd(128, 128); + + if (UsesHttp3()) { + EXPECT_CALL(session_, + MaybeSendStopSendingFrame(_, QuicResetStreamError::FromInternal( + QUIC_STREAM_NO_ERROR))) + .Times(1); + } else { + EXPECT_CALL( + session_, + MaybeSendRstStreamFrame( + _, QuicResetStreamError::FromInternal(QUIC_STREAM_NO_ERROR), _)) + .Times(1); + } EXPECT_CALL(*stream_, WriteHeadersMock(/*fin=*/false)); stream_->OnStreamHeaderList(/*fin=*/false, kFakeFrameLen, header_list); - std::unique_ptr buffer; - QuicByteCount header_length = - HttpEncoder::SerializeDataFrameHeader(body_.length(), &buffer); - std::string header = std::string(buffer.get(), header_length); - std::string data = UsesHttp3() ? header + body_ : body_; - stream_->OnStreamFrame( - QuicStreamFrame(stream_->id(), /*fin=*/false, /*offset=*/0, data)); - EXPECT_EQ("CONNECT-SILLY", StreamHeadersValue(":method")); - EXPECT_EQ(body_, StreamBody()); EXPECT_FALSE(stream_->send_response_was_called()); EXPECT_TRUE(stream_->send_error_response_was_called()); } diff --git a/gquiche/quic/tools/quic_spdy_client_base.cc b/gquiche/quic/tools/quic_spdy_client_base.cc index 553dbf6c..73ebdf34 100644 --- a/gquiche/quic/tools/quic_spdy_client_base.cc +++ b/gquiche/quic/tools/quic_spdy_client_base.cc @@ -13,7 +13,7 @@ #include "gquiche/quic/core/quic_server_id.h" #include "gquiche/quic/platform/api/quic_flags.h" #include "gquiche/quic/platform/api/quic_logging.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" using spdy::Http2HeaderBlock; @@ -67,6 +67,10 @@ const QuicSpdyClientSession* QuicSpdyClientBase::client_session() const { } void QuicSpdyClientBase::InitializeSession() { + if (max_inbound_header_list_size_ > 0) { + client_session()->set_max_inbound_header_list_size( + max_inbound_header_list_size_); + } client_session()->Initialize(); client_session()->CryptoConnect(); } diff --git a/gquiche/quic/tools/quic_spdy_client_base.h b/gquiche/quic/tools/quic_spdy_client_base.h index 4baa06a1..053099aa 100644 --- a/gquiche/quic/tools/quic_spdy_client_base.h +++ b/gquiche/quic/tools/quic_spdy_client_base.h @@ -144,11 +144,20 @@ class QuicSpdyClientBase : public QuicClientBase, } bool enable_web_transport() const { return enable_web_transport_; } + void set_use_datagram_contexts(bool use_datagram_contexts) { + use_datagram_contexts_ = use_datagram_contexts; + } + bool use_datagram_contexts() const { return use_datagram_contexts_; } + // QuicClientBase methods. bool goaway_received() const override; bool EarlyDataAccepted() override; bool ReceivedInchoateReject() override; + void set_max_inbound_header_list_size(size_t size) { + max_inbound_header_list_size_ = size; + } + protected: int GetNumSentClientHellosFromSession() override; int GetNumReceivedServerConfigUpdatesFromSession() override; @@ -223,6 +232,10 @@ class QuicSpdyClientBase : public QuicClientBase, bool drop_response_body_ = false; bool enable_web_transport_ = false; + bool use_datagram_contexts_ = false; + // If not zero, used to set client's max inbound header size before session + // initialize. + size_t max_inbound_header_list_size_ = 0; }; } // namespace quic diff --git a/gquiche/quic/tools/quic_toy_client.cc b/gquiche/quic/tools/quic_toy_client.cc index 15586527..816621cf 100644 --- a/gquiche/quic/tools/quic_toy_client.cc +++ b/gquiche/quic/tools/quic_toy_client.cc @@ -42,6 +42,7 @@ #include "gquiche/quic/tools/quic_toy_client.h" +#include #include #include #include @@ -51,6 +52,7 @@ #include "absl/strings/escaping.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "gquiche/quic/core/crypto/quic_client_session_cache.h" #include "gquiche/quic/core/quic_packets.h" #include "gquiche/quic/core/quic_server_id.h" #include "gquiche/quic/core/quic_utils.h" @@ -61,7 +63,7 @@ #include "gquiche/quic/platform/api/quic_system_event_loop.h" #include "gquiche/quic/tools/fake_proof_verifier.h" #include "gquiche/quic/tools/quic_url.h" -#include "gquiche/common/platform/api/quiche_text_utils.h" +#include "gquiche/common/quiche_text_utils.h" namespace { @@ -182,6 +184,16 @@ DEFINE_QUIC_COMMAND_LINE_FLAG(bool, false, "If true, don't verify the server certificate."); +DEFINE_QUIC_COMMAND_LINE_FLAG( + std::string, default_client_cert, "", + "The path to the file containing PEM-encoded client default certificate to " + "be sent to the server, if server requested client certs."); + +DEFINE_QUIC_COMMAND_LINE_FLAG( + std::string, default_client_cert_key, "", + "The path to the file containing PEM-encoded private key of the client's " + "default certificate for signing, if server requested client certs."); + DEFINE_QUIC_COMMAND_LINE_FLAG( bool, drop_response_body, @@ -194,6 +206,12 @@ DEFINE_QUIC_COMMAND_LINE_FLAG( false, "If true, do not change local port after each request."); +DEFINE_QUIC_COMMAND_LINE_FLAG(bool, + one_connection_per_request, + false, + "If true, close the connection after each " + "request. This allows testing 0-RTT."); + DEFINE_QUIC_COMMAND_LINE_FLAG(int32_t, server_connection_id_length, -1, @@ -204,7 +222,50 @@ DEFINE_QUIC_COMMAND_LINE_FLAG(int32_t, -1, "Length of the client connection ID used."); +DEFINE_QUIC_COMMAND_LINE_FLAG(int32_t, max_time_before_crypto_handshake_ms, + 10000, + "Max time to wait before handshake completes."); + +DEFINE_QUIC_COMMAND_LINE_FLAG(int32_t, max_inbound_header_list_size, 128 * 1024, + "Max inbound header list size. 0 means default."); + namespace quic { +namespace { + +// Creates a ClientProofSource which only contains a default client certificate. +// Return nullptr for failure. +std::unique_ptr CreateTestClientProofSource( + absl::string_view default_client_cert_file, + absl::string_view default_client_cert_key_file) { + std::ifstream cert_stream(std::string{default_client_cert_file}, + std::ios::binary); + std::vector certs = + CertificateView::LoadPemFromStream(&cert_stream); + if (certs.empty()) { + std::cerr << "Failed to load client certs." << std::endl; + return nullptr; + } + + std::ifstream key_stream(std::string{default_client_cert_key_file}, + std::ios::binary); + std::unique_ptr private_key = + CertificatePrivateKey::LoadPemFromStream(&key_stream); + if (private_key == nullptr) { + std::cerr << "Failed to load client cert key." << std::endl; + return nullptr; + } + + auto proof_source = std::make_unique(); + proof_source->AddCertAndKey( + {"*"}, + QuicReferenceCountedPointer( + new ClientProofSource::Chain(certs)), + std::move(*private_key)); + + return proof_source; +} + +} // namespace QuicToyClient::QuicToyClient(ClientFactory* client_factory) : client_factory_(client_factory) {} @@ -260,6 +321,10 @@ int QuicToyClient::SendRequestsAndPrintResponses( } else { proof_verifier = quic::CreateDefaultProofVerifier(url.host()); } + std::unique_ptr session_cache; + if (num_requests > 1 && GetQuicFlag(FLAGS_one_connection_per_request)) { + session_cache = std::make_unique(); + } QuicConfig config; std::string connection_options_string = GetQuicFlag(FLAGS_connection_options); @@ -282,6 +347,8 @@ int QuicToyClient::SendRequestsAndPrintResponses( config.custom_transport_parameters_to_send()[kCustomParameter] = custom_value; } + config.set_max_time_before_crypto_handshake(QuicTime::Delta::FromMilliseconds( + GetQuicFlag(FLAGS_max_time_before_crypto_handshake_ms))); int address_family_for_lookup = AF_UNSPEC; if (GetQuicFlag(FLAGS_ip_version_for_host_lookup) == "4") { @@ -293,13 +360,25 @@ int QuicToyClient::SendRequestsAndPrintResponses( // Build the client, and try to connect. std::unique_ptr client = client_factory_->CreateClient( url.host(), host, address_family_for_lookup, port, versions, config, - std::move(proof_verifier)); + std::move(proof_verifier), std::move(session_cache)); if (client == nullptr) { std::cerr << "Failed to create client." << std::endl; return 1; } + if (!GetQuicFlag(FLAGS_default_client_cert).empty() && + !GetQuicFlag(FLAGS_default_client_cert_key).empty()) { + std::unique_ptr proof_source = + CreateTestClientProofSource(GetQuicFlag(FLAGS_default_client_cert), + GetQuicFlag(FLAGS_default_client_cert_key)); + if (proof_source == nullptr) { + std::cerr << "Failed to create client proof source." << std::endl; + return 1; + } + client->crypto_config()->set_proof_source(std::move(proof_source)); + } + int32_t initial_mtu = GetQuicFlag(FLAGS_initial_mtu); client->set_initial_max_packet_length( initial_mtu != 0 ? initial_mtu : quic::kDefaultMaxPacketSize); @@ -314,6 +393,11 @@ int QuicToyClient::SendRequestsAndPrintResponses( if (client_connection_id_length >= 0) { client->set_client_connection_id_length(client_connection_id_length); } + const size_t max_inbound_header_list_size = + GetQuicFlag(FLAGS_max_inbound_header_list_size); + if (max_inbound_header_list_size > 0) { + client->set_max_inbound_header_list_size(max_inbound_header_list_size); + } if (!client->Initialize()) { std::cerr << "Failed to initialize client." << std::endl; return 1; @@ -356,7 +440,8 @@ int QuicToyClient::SendRequestsAndPrintResponses( if (sp.empty()) { continue; } - std::vector kv = absl::StrSplit(sp, ':'); + std::vector kv = + absl::StrSplit(sp, absl::MaxSplits(':', 1)); QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&kv[0]); QuicheTextUtils::RemoveLeadingAndTrailingWhitespace(&kv[1]); header_block[kv[0]] = kv[1]; @@ -429,11 +514,26 @@ int QuicToyClient::SendRequestsAndPrintResponses( return 1; } - // Change the ephemeral port if there are more requests to do. - if (!GetQuicFlag(FLAGS_disable_port_changes) && i + 1 < num_requests) { - if (!client->ChangeEphemeralPort()) { - std::cerr << "Failed to change ephemeral port." << std::endl; - return 1; + if (i + 1 < num_requests) { // There are more requests to perform. + if (GetQuicFlag(FLAGS_one_connection_per_request)) { + std::cout << "Disconnecting client between requests." << std::endl; + client->Disconnect(); + if (!client->Initialize()) { + std::cerr << "Failed to reinitialize client between requests." + << std::endl; + return 1; + } + if (!client->Connect()) { + std::cerr << "Failed to reconnect client between requests." + << std::endl; + return 1; + } + } else if (!GetQuicFlag(FLAGS_disable_port_changes)) { + // Change the ephemeral port. + if (!client->ChangeEphemeralPort()) { + std::cerr << "Failed to change ephemeral port." << std::endl; + return 1; + } } } } diff --git a/gquiche/quic/tools/quic_toy_client.h b/gquiche/quic/tools/quic_toy_client.h index c464721f..fd90ce46 100644 --- a/gquiche/quic/tools/quic_toy_client.h +++ b/gquiche/quic/tools/quic_toy_client.h @@ -29,7 +29,8 @@ class QuicToyClient { uint16_t port, ParsedQuicVersionVector versions, const QuicConfig& config, - std::unique_ptr verifier) = 0; + std::unique_ptr verifier, + std::unique_ptr session_cache) = 0; }; // Constructs a new toy client that will use |client_factory| to create the diff --git a/gquiche/quic/tools/quic_toy_server.cc b/gquiche/quic/tools/quic_toy_server.cc index 4dd3f9a5..1f0a7987 100644 --- a/gquiche/quic/tools/quic_toy_server.cc +++ b/gquiche/quic/tools/quic_toy_server.cc @@ -46,6 +46,11 @@ DEFINE_QUIC_COMMAND_LINE_FLAG( "QUIC versions to enable, e.g. \"h3-25,h3-27\". If not set, then all " "available versions are enabled."); +DEFINE_QUIC_COMMAND_LINE_FLAG(bool, + enable_webtransport, + false, + "If true, WebTransport support is enabled."); + namespace quic { std::unique_ptr @@ -58,6 +63,9 @@ QuicToyServer::MemoryCacheBackendFactory::CreateBackend() { memory_cache_backend->InitializeBackend( GetQuicFlag(FLAGS_quic_response_cache_dir)); } + if (GetQuicFlag(FLAGS_enable_webtransport)) { + memory_cache_backend->EnableWebTransport(); + } return memory_cache_backend; } diff --git a/gquiche/quic/tools/quic_transport_simple_server_dispatcher.cc b/gquiche/quic/tools/quic_transport_simple_server_dispatcher.cc index 6628146b..f24be23f 100644 --- a/gquiche/quic/tools/quic_transport_simple_server_dispatcher.cc +++ b/gquiche/quic/tools/quic_transport_simple_server_dispatcher.cc @@ -37,10 +37,9 @@ std::unique_ptr QuicTransportSimpleServerDispatcher::CreateQuicSession( QuicConnectionId server_connection_id, const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view /*alpn*/, + const QuicSocketAddress& peer_address, absl::string_view /*alpn*/, const ParsedQuicVersion& version, - absl::string_view /*sni*/) { + const ParsedClientHello& /*parsed_chlo*/) { auto connection = std::make_unique( server_connection_id, self_address, peer_address, helper(), alarm_factory(), writer(), diff --git a/gquiche/quic/tools/quic_transport_simple_server_dispatcher.h b/gquiche/quic/tools/quic_transport_simple_server_dispatcher.h index 4ab055d2..7344b474 100644 --- a/gquiche/quic/tools/quic_transport_simple_server_dispatcher.h +++ b/gquiche/quic/tools/quic_transport_simple_server_dispatcher.h @@ -30,10 +30,9 @@ class QuicTransportSimpleServerDispatcher : public QuicDispatcher { std::unique_ptr CreateQuicSession( QuicConnectionId server_connection_id, const QuicSocketAddress& self_address, - const QuicSocketAddress& peer_address, - absl::string_view alpn, + const QuicSocketAddress& peer_address, absl::string_view alpn, const ParsedQuicVersion& version, - absl::string_view sni) override; + const ParsedClientHello& parsed_chlo) override; std::vector accepted_origins_; }; diff --git a/gquiche/quic/tools/quic_transport_simple_server_session.h b/gquiche/quic/tools/quic_transport_simple_server_session.h index f4f8de6f..66e11d08 100644 --- a/gquiche/quic/tools/quic_transport_simple_server_session.h +++ b/gquiche/quic/tools/quic_transport_simple_server_session.h @@ -73,7 +73,7 @@ class QuicTransportSimpleServerSession size_t pending_outgoing_bidirectional_streams_ = 0u; Mode mode_; std::vector accepted_origins_; - QuicCircularDeque streams_to_echo_back_; + quiche::QuicheCircularDeque streams_to_echo_back_; }; } // namespace quic diff --git a/gquiche/quic/tools/simple_ticket_crypter.cc b/gquiche/quic/tools/simple_ticket_crypter.cc index 7513b10f..e52fdb81 100644 --- a/gquiche/quic/tools/simple_ticket_crypter.cc +++ b/gquiche/quic/tools/simple_ticket_crypter.cc @@ -38,7 +38,12 @@ size_t SimpleTicketCrypter::MaxOverhead() { return kEpochSize + kIVSize + kAuthTagSize; } -std::vector SimpleTicketCrypter::Encrypt(absl::string_view in) { +std::vector SimpleTicketCrypter::Encrypt( + absl::string_view in, absl::string_view encryption_key) { + // This class is only used in Chromium, in which the |encryption_key| argument + // will never be populated and an internally-cached key should be used for + // encrypting tickets. + QUICHE_DCHECK(encryption_key.empty()); MaybeRotateKeys(); std::vector out(in.size() + MaxOverhead()); out[0] = key_epoch_; diff --git a/gquiche/quic/tools/simple_ticket_crypter.h b/gquiche/quic/tools/simple_ticket_crypter.h index 8dfdea13..362cfc9c 100644 --- a/gquiche/quic/tools/simple_ticket_crypter.h +++ b/gquiche/quic/tools/simple_ticket_crypter.h @@ -24,7 +24,8 @@ class QUIC_NO_EXPORT SimpleTicketCrypter ~SimpleTicketCrypter() override; size_t MaxOverhead() override; - std::vector Encrypt(absl::string_view in) override; + std::vector Encrypt(absl::string_view in, + absl::string_view encryption_key) override; void Decrypt( absl::string_view in, std::unique_ptr callback) override; diff --git a/gquiche/quic/tools/simple_ticket_crypter_test.cc b/gquiche/quic/tools/simple_ticket_crypter_test.cc index 9f6360a4..268f9f87 100644 --- a/gquiche/quic/tools/simple_ticket_crypter_test.cc +++ b/gquiche/quic/tools/simple_ticket_crypter_test.cc @@ -42,7 +42,7 @@ class SimpleTicketCrypterTest : public QuicTest { TEST_F(SimpleTicketCrypterTest, EncryptDecrypt) { std::vector plaintext = {1, 2, 3, 4, 5}; std::vector ciphertext = - ticket_crypter_.Encrypt(StringPiece(plaintext)); + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); EXPECT_NE(plaintext, ciphertext); std::vector out_plaintext; @@ -54,16 +54,16 @@ TEST_F(SimpleTicketCrypterTest, EncryptDecrypt) { TEST_F(SimpleTicketCrypterTest, CiphertextsDiffer) { std::vector plaintext = {1, 2, 3, 4, 5}; std::vector ciphertext1 = - ticket_crypter_.Encrypt(StringPiece(plaintext)); + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); std::vector ciphertext2 = - ticket_crypter_.Encrypt(StringPiece(plaintext)); + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); EXPECT_NE(ciphertext1, ciphertext2); } TEST_F(SimpleTicketCrypterTest, DecryptionFailureWithModifiedCiphertext) { std::vector plaintext = {1, 2, 3, 4, 5}; std::vector ciphertext = - ticket_crypter_.Encrypt(StringPiece(plaintext)); + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); EXPECT_NE(plaintext, ciphertext); // Check that a bit flip in any byte will cause a decryption failure. @@ -88,7 +88,7 @@ TEST_F(SimpleTicketCrypterTest, DecryptionFailureWithEmptyCiphertext) { TEST_F(SimpleTicketCrypterTest, KeyRotation) { std::vector plaintext = {1, 2, 3}; std::vector ciphertext = - ticket_crypter_.Encrypt(StringPiece(plaintext)); + ticket_crypter_.Encrypt(StringPiece(plaintext), {}); EXPECT_FALSE(ciphertext.empty()); // Advance the clock 8 days, so the key used for |ciphertext| is now the diff --git a/gquiche/quic/tools/web_transport_test_visitors.h b/gquiche/quic/tools/web_transport_test_visitors.h index c77515c0..29ae414b 100644 --- a/gquiche/quic/tools/web_transport_test_visitors.h +++ b/gquiche/quic/tools/web_transport_test_visitors.h @@ -7,8 +7,10 @@ #include +#include "gquiche/quic/core/quic_simple_buffer_allocator.h" #include "gquiche/quic/core/web_transport_interface.h" #include "gquiche/quic/platform/api/quic_logging.h" +#include "gquiche/common/quiche_circular_deque.h" namespace quic { @@ -27,6 +29,10 @@ class WebTransportDiscardVisitor : public WebTransportStreamVisitor { void OnCanWrite() override {} + void OnResetStreamReceived(WebTransportStreamError /*error*/) override {} + void OnStopSendingReceived(WebTransportStreamError /*error*/) override {} + void OnWriteSideInDataRecvdState() override {} + private: WebTransportStream* stream_; }; @@ -49,6 +55,10 @@ class WebTransportBidirectionalEchoVisitor : public WebTransportStreamVisitor { } void OnCanWrite() override { + if (stop_sending_received_) { + return; + } + if (!buffer_.empty()) { bool success = stream_->Write(buffer_); QUIC_DVLOG(1) << "Attempted writing on WebTransport bidirectional stream " @@ -67,10 +77,26 @@ class WebTransportBidirectionalEchoVisitor : public WebTransportStreamVisitor { } } + void OnResetStreamReceived(WebTransportStreamError /*error*/) override { + // Send FIN in response to a stream reset. We want to test that we can + // operate one side of the stream cleanly while the other is reset, thus + // replying with a FIN rather than a RESET_STREAM is more appropriate here. + send_fin_ = true; + OnCanWrite(); + } + void OnStopSendingReceived(WebTransportStreamError /*error*/) override { + stop_sending_received_ = true; + } + void OnWriteSideInDataRecvdState() override {} + + protected: + WebTransportStream* stream() { return stream_; } + private: WebTransportStream* stream_; std::string buffer_; bool send_fin_ = false; + bool stop_sending_received_ = false; }; // Buffers all of the data and calls |callback| with the entirety of the stream @@ -98,6 +124,10 @@ class WebTransportUnidirectionalEchoReadVisitor void OnCanWrite() override { QUIC_NOTREACHED(); } + void OnResetStreamReceived(WebTransportStreamError /*error*/) override {} + void OnStopSendingReceived(WebTransportStreamError /*error*/) override {} + void OnWriteSideInDataRecvdState() override {} + private: WebTransportStream* stream_; std::string buffer_; @@ -127,11 +157,107 @@ class WebTransportUnidirectionalEchoWriteVisitor QUICHE_DCHECK(fin_sent); } + void OnResetStreamReceived(WebTransportStreamError /*error*/) override {} + void OnStopSendingReceived(WebTransportStreamError /*error*/) override {} + void OnWriteSideInDataRecvdState() override {} + private: WebTransportStream* stream_; std::string data_; }; +// A session visitor which sets unidirectional or bidirectional stream visitors +// to echo. +class EchoWebTransportSessionVisitor : public WebTransportVisitor { + public: + EchoWebTransportSessionVisitor(WebTransportSession* session) + : session_(session) {} + + void OnSessionReady(const spdy::SpdyHeaderBlock&) override { + if (session_->CanOpenNextOutgoingBidirectionalStream()) { + OnCanCreateNewOutgoingBidirectionalStream(); + } + } + + void OnSessionClosed(WebTransportSessionError /*error_code*/, + const std::string& /*error_message*/) override {} + + void OnIncomingBidirectionalStreamAvailable() override { + while (true) { + WebTransportStream* stream = + session_->AcceptIncomingBidirectionalStream(); + if (stream == nullptr) { + return; + } + QUIC_DVLOG(1) + << "EchoWebTransportSessionVisitor received a bidirectional stream " + << stream->GetStreamId(); + stream->SetVisitor( + std::make_unique(stream)); + stream->visitor()->OnCanRead(); + } + } + + void OnIncomingUnidirectionalStreamAvailable() override { + while (true) { + WebTransportStream* stream = + session_->AcceptIncomingUnidirectionalStream(); + if (stream == nullptr) { + return; + } + QUIC_DVLOG(1) + << "EchoWebTransportSessionVisitor received a unidirectional stream"; + stream->SetVisitor( + std::make_unique( + stream, [this](const std::string& data) { + streams_to_echo_back_.push_back(data); + TrySendingUnidirectionalStreams(); + })); + stream->visitor()->OnCanRead(); + } + } + + void OnDatagramReceived(absl::string_view datagram) override { + auto buffer = MakeUniqueBuffer(&allocator_, datagram.size()); + memcpy(buffer.get(), datagram.data(), datagram.size()); + QuicMemSlice slice(std::move(buffer), datagram.size()); + session_->SendOrQueueDatagram(std::move(slice)); + } + + void OnCanCreateNewOutgoingBidirectionalStream() override { + if (!echo_stream_opened_) { + WebTransportStream* stream = session_->OpenOutgoingBidirectionalStream(); + stream->SetVisitor( + std::make_unique(stream)); + echo_stream_opened_ = true; + } + } + void OnCanCreateNewOutgoingUnidirectionalStream() override { + TrySendingUnidirectionalStreams(); + } + + void TrySendingUnidirectionalStreams() { + while (!streams_to_echo_back_.empty() && + session_->CanOpenNextOutgoingUnidirectionalStream()) { + QUIC_DVLOG(1) + << "EchoWebTransportServer echoed a unidirectional stream back"; + WebTransportStream* stream = session_->OpenOutgoingUnidirectionalStream(); + stream->SetVisitor( + std::make_unique( + stream, streams_to_echo_back_.front())); + streams_to_echo_back_.pop_front(); + stream->visitor()->OnCanWrite(); + } + } + + private: + WebTransportSession* session_; + SimpleBufferAllocator allocator_; + bool echo_stream_opened_ = false; + + quiche::QuicheCircularDeque streams_to_echo_back_; +}; + } // namespace quic #endif // QUICHE_QUIC_TOOLS_WEB_TRANSPORT_TEST_VISITORS_H_ diff --git a/gquiche/spdy/core/header_byte_listener_interface.h b/gquiche/spdy/core/header_byte_listener_interface.h new file mode 100644 index 00000000..7ac54718 --- /dev/null +++ b/gquiche/spdy/core/header_byte_listener_interface.h @@ -0,0 +1,22 @@ +#ifndef QUICHE_SPDY_CORE_HEADER_BYTE_LISTENER_INTERFACE_H_ +#define QUICHE_SPDY_CORE_HEADER_BYTE_LISTENER_INTERFACE_H_ + +#include + +#include "gquiche/common/platform/api/quiche_export.h" + +namespace spdy { + +// Listens for the receipt of uncompressed header bytes. +class QUICHE_EXPORT_PRIVATE HeaderByteListenerInterface { + public: + virtual ~HeaderByteListenerInterface() {} + + // Called when a header block has been parsed, with the number of uncompressed + // header bytes parsed from the header block. + virtual void OnHeaderBytesReceived(size_t uncompressed_header_bytes) = 0; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HEADER_BYTE_LISTENER_INTERFACE_H_ diff --git a/gquiche/spdy/core/hpack/hpack_decoder_adapter.cc b/gquiche/spdy/core/hpack/hpack_decoder_adapter.cc index b6969de2..2c9a0edf 100644 --- a/gquiche/spdy/core/hpack/hpack_decoder_adapter.cc +++ b/gquiche/spdy/core/hpack/hpack_decoder_adapter.cc @@ -7,7 +7,6 @@ #include "gquiche/http2/decoder/decode_buffer.h" #include "gquiche/http2/decoder/decode_status.h" #include "gquiche/common/platform/api/quiche_logging.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" using ::http2::DecodeBuffer; @@ -30,6 +29,10 @@ void HpackDecoderAdapter::ApplyHeaderTableSizeSetting(size_t size_setting) { hpack_decoder_.ApplyHeaderTableSizeSetting(size_setting); } +size_t HpackDecoderAdapter::GetCurrentHeaderTableSizeSetting() const { + return hpack_decoder_.GetCurrentHeaderTableSizeSetting(); +} + void HpackDecoderAdapter::HandleControlFrameHeadersStart( SpdyHeadersHandlerInterface* handler) { QUICHE_DVLOG(2) << "HpackDecoderAdapter::HandleControlFrameHeadersStart"; @@ -86,12 +89,8 @@ bool HpackDecoderAdapter::HandleControlFrameHeadersData( return true; } -bool HpackDecoderAdapter::HandleControlFrameHeadersComplete( - size_t* compressed_len) { +bool HpackDecoderAdapter::HandleControlFrameHeadersComplete() { QUICHE_DVLOG(2) << "HpackDecoderAdapter::HandleControlFrameHeadersComplete"; - if (compressed_len != nullptr) { - *compressed_len = listener_adapter_.total_hpack_bytes(); - } if (!hpack_decoder_.EndDecodingBlock()) { QUICHE_DVLOG(3) << "EndDecodingBlock returned false"; error_ = hpack_decoder_.error(); @@ -118,10 +117,6 @@ void HpackDecoderAdapter::set_max_header_block_bytes( max_header_block_bytes_ = max_header_block_bytes; } -size_t HpackDecoderAdapter::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(hpack_decoder_); -} - HpackDecoderAdapter::ListenerAdapter::ListenerAdapter() : handler_(nullptr) {} HpackDecoderAdapter::ListenerAdapter::~ListenerAdapter() = default; diff --git a/gquiche/spdy/core/hpack/hpack_decoder_adapter.h b/gquiche/spdy/core/hpack/hpack_decoder_adapter.h index a38ff6f2..92001668 100644 --- a/gquiche/spdy/core/hpack/hpack_decoder_adapter.h +++ b/gquiche/spdy/core/hpack/hpack_decoder_adapter.h @@ -39,6 +39,9 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderAdapter { // Called upon acknowledgement of SETTINGS_HEADER_TABLE_SIZE. void ApplyHeaderTableSizeSetting(size_t size_setting); + // Returns the most recently applied value of SETTINGS_HEADER_TABLE_SIZE. + size_t GetCurrentHeaderTableSizeSetting() const; + // If a SpdyHeadersHandlerInterface is provided, the decoder will emit // headers to it rather than accumulating them in a SpdyHeaderBlock. // Does not take ownership of the handler, but does use the pointer until @@ -56,10 +59,7 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderAdapter { // buffered block that was accumulated in HandleControlFrameHeadersData(), // to support subsequent calculation of compression percentage. // Discards the handler supplied at the start of decoding the block. - // TODO(jamessynge): Determine if compressed_len is needed; it is used to - // produce UUMA stat Net.SpdyHpackDecompressionPercentage, but only for - // deprecated SPDY3. - bool HandleControlFrameHeadersComplete(size_t* compressed_len); + bool HandleControlFrameHeadersComplete(); // Accessor for the most recently decoded headers block. Valid until the next // call to HandleControlFrameHeadersData(). @@ -67,6 +67,12 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderAdapter { // a SpdyHeadersHandlerInterface. const SpdyHeaderBlock& decoded_block() const; + // Returns the current dynamic table size, including the 32 bytes per entry + // overhead mentioned in RFC 7541 section 4.1. + size_t GetDynamicTableSize() const { + return hpack_decoder_.GetDynamicTableSize(); + } + // Set how much encoded data this decoder is willing to buffer. // TODO(jamessynge): Resolve definition of this value, as it is currently // too tied to a single implementation. We probably want to limit one or more @@ -79,8 +85,6 @@ class QUICHE_EXPORT_PRIVATE HpackDecoderAdapter { // accepted. void set_max_header_block_bytes(size_t max_header_block_bytes); - size_t EstimateMemoryUsage() const; - // Error code if an error has occurred, Error::kOk otherwise. http2::HpackDecodingError error() const { return error_; } diff --git a/gquiche/spdy/core/hpack/hpack_decoder_adapter_test.cc b/gquiche/spdy/core/hpack/hpack_decoder_adapter_test.cc index 8a285821..dd417290 100644 --- a/gquiche/spdy/core/hpack/hpack_decoder_adapter_test.cc +++ b/gquiche/spdy/core/hpack/hpack_decoder_adapter_test.cc @@ -21,12 +21,12 @@ #include "gquiche/http2/test_tools/http2_random.h" #include "gquiche/common/platform/api/quiche_logging.h" #include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/common/quiche_text_utils.h" #include "gquiche/spdy/core/hpack/hpack_constants.h" #include "gquiche/spdy/core/hpack/hpack_encoder.h" #include "gquiche/spdy/core/hpack/hpack_output_stream.h" #include "gquiche/spdy/core/recording_headers_handler.h" #include "gquiche/spdy/core/spdy_test_utils.h" -#include "gquiche/spdy/platform/api/spdy_string_utils.h" using ::http2::HpackEntryType; using ::http2::HpackStringPair; @@ -135,16 +135,14 @@ class HpackDecoderAdapterTest } bool HandleControlFrameHeadersData(absl::string_view str) { - QUICHE_VLOG(3) << "HandleControlFrameHeadersData:\n" << SpdyHexDump(str); + QUICHE_VLOG(3) << "HandleControlFrameHeadersData:\n" + << quiche::QuicheTextUtils::HexDump(str); bytes_passed_in_ += str.size(); return decoder_.HandleControlFrameHeadersData(str.data(), str.size()); } - bool HandleControlFrameHeadersComplete(size_t* size) { - bool rc = decoder_.HandleControlFrameHeadersComplete(size); - if (size != nullptr) { - EXPECT_EQ(*size, bytes_passed_in_); - } + bool HandleControlFrameHeadersComplete() { + bool rc = decoder_.HandleControlFrameHeadersComplete(); return rc; } @@ -171,22 +169,18 @@ class HpackDecoderAdapterTest decode_has_failed_ = true; return false; } - // Want to get out the number of compressed bytes that were decoded, - // so pass in a pointer if no handler. - size_t total_hpack_bytes = 0; if (start_choice_ == START_WITH_HANDLER) { - if (!HandleControlFrameHeadersComplete(nullptr)) { + if (!HandleControlFrameHeadersComplete()) { decode_has_failed_ = true; return false; } - total_hpack_bytes = handler_.compressed_header_bytes(); + EXPECT_EQ(handler_.compressed_header_bytes(), bytes_passed_in_); } else { - if (!HandleControlFrameHeadersComplete(&total_hpack_bytes)) { + if (!HandleControlFrameHeadersComplete()) { decode_has_failed_ = true; return false; } } - EXPECT_EQ(total_hpack_bytes, bytes_passed_in_); if (check_decoded_size && start_choice_ == START_WITH_HANDLER) { EXPECT_EQ(handler_.uncompressed_header_bytes(), SizeOfHeaders(decoded_block())); @@ -273,6 +267,12 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine(::testing::Values(START_WITH_HANDLER), ::testing::Bool())); +TEST_P(HpackDecoderAdapterTest, ApplyHeaderTableSizeSetting) { + EXPECT_EQ(4096u, decoder_.GetCurrentHeaderTableSizeSetting()); + decoder_.ApplyHeaderTableSizeSetting(12 * 1024); + EXPECT_EQ(12288u, decoder_.GetCurrentHeaderTableSizeSetting()); +} + TEST_P(HpackDecoderAdapterTest, AddHeaderDataWithHandleControlFrameHeadersData) { // The hpack decode buffer size is limited in size. This test verifies that @@ -351,8 +351,7 @@ TEST_P(HpackDecoderAdapterTest, HeaderBlockTooLong) { // entire block successfully. HandleControlFrameHeadersStart(); EXPECT_TRUE(HandleControlFrameHeadersData(hbb.buffer())); - size_t total_bytes; - EXPECT_TRUE(HandleControlFrameHeadersComplete(&total_bytes)); + EXPECT_TRUE(HandleControlFrameHeadersComplete()); // When a total byte limit is imposed, the decoder bails before the end of the // block. @@ -386,9 +385,7 @@ TEST_P(HpackDecoderAdapterTest, DecodeWithIncompleteData) { // Add the needed data. EXPECT_TRUE(HandleControlFrameHeadersData("\x04gggs")); - size_t size = 0; - EXPECT_TRUE(HandleControlFrameHeadersComplete(&size)); - EXPECT_EQ(24u, size); + EXPECT_TRUE(HandleControlFrameHeadersComplete()); expected_headers.push_back({"spam", "gggs"}); @@ -428,7 +425,7 @@ TEST_P(HpackDecoderAdapterTest, HandleHeaderRepresentation) { decoder_peer_.HandleHeaderRepresentation("cookie", " fin!"); // Finish and emit all headers. - decoder_.HandleControlFrameHeadersComplete(nullptr); + decoder_.HandleControlFrameHeadersComplete(); // Resulting decoded headers are in the same order as the inputs. EXPECT_THAT( @@ -502,8 +499,8 @@ TEST_P(HpackDecoderAdapterTest, ContextUpdateMaximumSize) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(126); - output_stream.TakeString(&input); - EXPECT_TRUE(DecodeHeaderBlock(absl::string_view(input))); + input = output_stream.TakeString(); + EXPECT_TRUE(DecodeHeaderBlock(input)); EXPECT_EQ(126u, decoder_peer_.header_table_size_limit()); } { @@ -512,8 +509,8 @@ TEST_P(HpackDecoderAdapterTest, ContextUpdateMaximumSize) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(kDefaultHeaderTableSizeSetting); - output_stream.TakeString(&input); - EXPECT_TRUE(DecodeHeaderBlock(absl::string_view(input))); + input = output_stream.TakeString(); + EXPECT_TRUE(DecodeHeaderBlock(input)); EXPECT_EQ(kDefaultHeaderTableSizeSetting, decoder_peer_.header_table_size_limit()); } @@ -523,8 +520,8 @@ TEST_P(HpackDecoderAdapterTest, ContextUpdateMaximumSize) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(kDefaultHeaderTableSizeSetting + 1); - output_stream.TakeString(&input); - EXPECT_FALSE(DecodeHeaderBlock(absl::string_view(input))); + input = output_stream.TakeString(); + EXPECT_FALSE(DecodeHeaderBlock(input)); EXPECT_EQ(kDefaultHeaderTableSizeSetting, decoder_peer_.header_table_size_limit()); } @@ -541,8 +538,8 @@ TEST_P(HpackDecoderAdapterTest, TwoTableSizeUpdates) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(122); - output_stream.TakeString(&input); - EXPECT_TRUE(DecodeHeaderBlock(absl::string_view(input))); + input = output_stream.TakeString(); + EXPECT_TRUE(DecodeHeaderBlock(input)); EXPECT_EQ(122u, decoder_peer_.header_table_size_limit()); } } @@ -560,9 +557,9 @@ TEST_P(HpackDecoderAdapterTest, ThreeTableSizeUpdatesError) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(15); - output_stream.TakeString(&input); + input = output_stream.TakeString(); - EXPECT_FALSE(DecodeHeaderBlock(absl::string_view(input))); + EXPECT_FALSE(DecodeHeaderBlock(input)); EXPECT_EQ(10u, decoder_peer_.header_table_size_limit()); } } @@ -579,9 +576,9 @@ TEST_P(HpackDecoderAdapterTest, TableSizeUpdateSecondError) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(123); - output_stream.TakeString(&input); + input = output_stream.TakeString(); - EXPECT_FALSE(DecodeHeaderBlock(absl::string_view(input))); + EXPECT_FALSE(DecodeHeaderBlock(input)); EXPECT_EQ(kDefaultHeaderTableSizeSetting, decoder_peer_.header_table_size_limit()); } @@ -602,9 +599,9 @@ TEST_P(HpackDecoderAdapterTest, TableSizeUpdateFirstThirdError) { output_stream.AppendPrefix(kHeaderTableSizeUpdateOpcode); output_stream.AppendUint32(125); - output_stream.TakeString(&input); + input = output_stream.TakeString(); - EXPECT_FALSE(DecodeHeaderBlock(absl::string_view(input))); + EXPECT_FALSE(DecodeHeaderBlock(input)); EXPECT_EQ(60u, decoder_peer_.header_table_size_limit()); } } @@ -718,9 +715,8 @@ TEST_P(HpackDecoderAdapterTest, BasicC31) { expected_header_set[":path"] = "/"; expected_header_set[":authority"] = "www.example.com"; - std::string encoded_header_set; - EXPECT_TRUE( - encoder.EncodeHeaderSet(expected_header_set, &encoded_header_set)); + std::string encoded_header_set = + encoder.EncodeHeaderBlock(expected_header_set); EXPECT_TRUE(DecodeHeaderBlock(encoded_header_set)); EXPECT_EQ(expected_header_set, decoded_block()); diff --git a/gquiche/spdy/core/hpack/hpack_encoder.cc b/gquiche/spdy/core/hpack/hpack_encoder.cc index 33a7df0f..07d2fc50 100644 --- a/gquiche/spdy/core/hpack/hpack_encoder.cc +++ b/gquiche/spdy/core/hpack/hpack_encoder.cc @@ -14,7 +14,6 @@ #include "gquiche/spdy/core/hpack/hpack_constants.h" #include "gquiche/spdy/core/hpack/hpack_header_table.h" #include "gquiche/spdy/core/hpack/hpack_output_stream.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" namespace spdy { @@ -85,8 +84,7 @@ HpackEncoder::HpackEncoder() HpackEncoder::~HpackEncoder() = default; -bool HpackEncoder::EncodeHeaderSet(const SpdyHeaderBlock& header_set, - std::string* output) { +std::string HpackEncoder::EncodeHeaderBlock(const SpdyHeaderBlock& header_set) { // Separate header set into pseudo-headers and regular headers. Representations pseudo_headers; Representations regular_headers; @@ -105,11 +103,8 @@ bool HpackEncoder::EncodeHeaderSet(const SpdyHeaderBlock& header_set, } } - { - RepresentationIterator iter(pseudo_headers, regular_headers); - EncodeRepresentations(&iter, output); - } - return true; + RepresentationIterator iter(pseudo_headers, regular_headers); + return EncodeRepresentations(&iter); } void HpackEncoder::ApplyHeaderTableSizeSetting(size_t size_setting) { @@ -124,13 +119,7 @@ void HpackEncoder::ApplyHeaderTableSizeSetting(size_t size_setting) { should_emit_table_size_ = true; } -size_t HpackEncoder::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(header_table_) + - SpdyEstimateMemoryUsage(output_stream_); -} - -void HpackEncoder::EncodeRepresentations(RepresentationIterator* iter, - std::string* output) { +std::string HpackEncoder::EncodeRepresentations(RepresentationIterator* iter) { MaybeEmitTableSize(); while (iter->HasNext()) { const auto header = iter->Next(); @@ -150,7 +139,7 @@ void HpackEncoder::EncodeRepresentations(RepresentationIterator* iter, } } - output_stream_.TakeString(output); + return output_stream_.TakeString(); } void HpackEncoder::EmitIndex(size_t index) { @@ -244,7 +233,7 @@ void HpackEncoder::CookieToCrumbs(const Representation& cookie, cookie_value = cookie_value.substr(first, (last - first) + 1); } for (size_t pos = 0;;) { - size_t end = cookie_value.find(";", pos); + size_t end = cookie_value.find(';', pos); if (end == absl::string_view::npos) { out->push_back(std::make_pair(cookie.first, cookie_value.substr(pos))); @@ -289,9 +278,8 @@ class HpackEncoder::Encoderator : public ProgressiveEncoder { // Returns true iff more remains to encode. bool HasNext() const override { return has_next_; } - // Encodes up to max_encoded_bytes of the current header block into the - // given output string. - void Next(size_t max_encoded_bytes, std::string* output) override; + // Encodes and returns up to max_encoded_bytes of the current header block. + std::string Next(size_t max_encoded_bytes) override; private: HpackEncoder* encoder_; @@ -344,8 +332,7 @@ HpackEncoder::Encoderator::Encoderator(const Representations& representations, encoder_->MaybeEmitTableSize(); } -void HpackEncoder::Encoderator::Next(size_t max_encoded_bytes, - std::string* output) { +std::string HpackEncoder::Encoderator::Next(size_t max_encoded_bytes) { QUICHE_BUG_IF(spdy_bug_61_1, !has_next_) << "Encoderator::Next called with nothing left to encode."; const bool enable_compression = encoder_->enable_compression_; @@ -371,7 +358,7 @@ void HpackEncoder::Encoderator::Next(size_t max_encoded_bytes, } has_next_ = encoder_->output_stream_.size() > max_encoded_bytes; - encoder_->output_stream_.BoundedTakeString(max_encoded_bytes, output); + return encoder_->output_stream_.BoundedTakeString(max_encoded_bytes); } std::unique_ptr HpackEncoder::EncodeHeaderSet( diff --git a/gquiche/spdy/core/hpack/hpack_encoder.h b/gquiche/spdy/core/hpack/hpack_encoder.h index b8a894cc..e2f539c7 100644 --- a/gquiche/spdy/core/hpack/hpack_encoder.h +++ b/gquiche/spdy/core/hpack/hpack_encoder.h @@ -49,9 +49,8 @@ class QUICHE_EXPORT_PRIVATE HpackEncoder { HpackEncoder& operator=(const HpackEncoder&) = delete; ~HpackEncoder(); - // Encodes the given header set into the given string. Returns - // whether or not the encoding was successful. - bool EncodeHeaderSet(const SpdyHeaderBlock& header_set, std::string* output); + // Encodes and returns the given header set as a string. + std::string EncodeHeaderBlock(const SpdyHeaderBlock& header_set); class QUICHE_EXPORT_PRIVATE ProgressiveEncoder { public: @@ -60,9 +59,8 @@ class QUICHE_EXPORT_PRIVATE HpackEncoder { // Returns true iff more remains to encode. virtual bool HasNext() const = 0; - // Encodes up to max_encoded_bytes of the current header block into the - // given output string. - virtual void Next(size_t max_encoded_bytes, std::string* output) = 0; + // Encodes and returns up to max_encoded_bytes of the current header block. + virtual std::string Next(size_t max_encoded_bytes) = 0; }; // Returns a ProgressiveEncoder which must be outlived by both the given @@ -81,6 +79,7 @@ class QUICHE_EXPORT_PRIVATE HpackEncoder { // SETTINGS_HEADER_TABLE_SIZE update from the remote decoding endpoint. void ApplyHeaderTableSizeSetting(size_t size_setting); + // TODO(birenroy): Rename this GetDynamicTableCapacity(). size_t CurrentHeaderTableSizeSetting() const { return header_table_.settings_size_bound(); } @@ -95,8 +94,9 @@ class QUICHE_EXPORT_PRIVATE HpackEncoder { void DisableCompression() { enable_compression_ = false; } - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; + // Returns the current dynamic table size, including the 32 bytes per entry + // overhead mentioned in RFC 7541 section 4.1. + size_t GetDynamicTableSize() const { return header_table_.size(); } private: friend class test::HpackEncoderPeer; @@ -105,7 +105,7 @@ class QUICHE_EXPORT_PRIVATE HpackEncoder { class Encoderator; // Encodes a sequence of header name-value pairs as a single header block. - void EncodeRepresentations(RepresentationIterator* iter, std::string* output); + std::string EncodeRepresentations(RepresentationIterator* iter); // Emits a static/dynamic indexed representation (Section 7.1). void EmitIndex(size_t index); diff --git a/gquiche/spdy/core/hpack/hpack_encoder_test.cc b/gquiche/spdy/core/hpack/hpack_encoder_test.cc index 9f561e80..24c2fb60 100644 --- a/gquiche/spdy/core/hpack/hpack_encoder_test.cc +++ b/gquiche/spdy/core/hpack/hpack_encoder_test.cc @@ -45,7 +45,7 @@ class HpackEncoderPeer { HpackHeaderTablePeer table_peer() { return HpackHeaderTablePeer(table()); } void EmitString(absl::string_view str) { encoder_->EmitString(str); } void TakeString(std::string* out) { - encoder_->output_stream_.TakeString(out); + *out = encoder_->output_stream_.TakeString(); } static void CookieToCrumbs(absl::string_view cookie, std::vector* out) { @@ -71,10 +71,9 @@ class HpackEncoderPeer { // TODO(dahollings): Remove or clean up these methods when deprecating // non-incremental encoding path. - static bool EncodeHeaderSet(HpackEncoder* encoder, - const SpdyHeaderBlock& header_set, - std::string* output) { - return encoder->EncodeHeaderSet(header_set, output); + static std::string EncodeHeaderBlock(HpackEncoder* encoder, + const SpdyHeaderBlock& header_set) { + return encoder->EncodeHeaderBlock(header_set); } static bool EncodeIncremental(HpackEncoder* encoder, @@ -82,12 +81,11 @@ class HpackEncoderPeer { std::string* output) { std::unique_ptr encoderator = encoder->EncodeHeaderSet(header_set); - std::string output_buffer; http2::test::Http2Random random; - encoderator->Next(random.UniformInRange(0, 16), &output_buffer); + std::string output_buffer = encoderator->Next(random.UniformInRange(0, 16)); while (encoderator->HasNext()) { - std::string second_buffer; - encoderator->Next(random.UniformInRange(0, 16), &second_buffer); + std::string second_buffer = + encoderator->Next(random.UniformInRange(0, 16)); output_buffer.append(second_buffer); } *output = std::move(output_buffer); @@ -99,12 +97,11 @@ class HpackEncoderPeer { std::string* output) { std::unique_ptr encoderator = encoder->EncodeRepresentations(representations); - std::string output_buffer; http2::test::Http2Random random; - encoderator->Next(random.UniformInRange(0, 16), &output_buffer); + std::string output_buffer = encoderator->Next(random.UniformInRange(0, 16)); while (encoderator->HasNext()) { - std::string second_buffer; - encoderator->Next(random.UniformInRange(0, 16), &second_buffer); + std::string second_buffer = + encoderator->Next(random.UniformInRange(0, 16)); output_buffer.append(second_buffer); } *output = std::move(output_buffer); @@ -155,6 +152,7 @@ class HpackEncoderTest : public QuicheTestWithParam { // No further insertions may occur without evictions. peer_.table()->SetMaxSize(peer_.table()->size()); + QUICHE_CHECK_EQ(kInitialDynamicTableSize, peer_.table()->size()); } void SaveHeaders(absl::string_view name, absl::string_view value) { @@ -218,12 +216,12 @@ class HpackEncoderTest : public QuicheTestWithParam { return r; } void CompareWithExpectedEncoding(const SpdyHeaderBlock& header_set) { - std::string expected_out, actual_out; - expected_.TakeString(&expected_out); + std::string actual_out; + std::string expected_out = expected_.TakeString(); switch (strategy_) { case kDefault: - EXPECT_TRUE(test::HpackEncoderPeer::EncodeHeaderSet( - &encoder_, header_set, &actual_out)); + actual_out = + test::HpackEncoderPeer::EncodeHeaderBlock(&encoder_, header_set); break; case kIncremental: EXPECT_TRUE(test::HpackEncoderPeer::EncodeIncremental( @@ -237,8 +235,8 @@ class HpackEncoderTest : public QuicheTestWithParam { EXPECT_EQ(expected_out, actual_out); } void CompareWithExpectedEncoding(const Representations& representations) { - std::string expected_out, actual_out; - expected_.TakeString(&expected_out); + std::string actual_out; + std::string expected_out = expected_.TakeString(); EXPECT_TRUE(test::HpackEncoderPeer::EncodeRepresentations( &encoder_, representations, &actual_out)); EXPECT_EQ(expected_out, actual_out); @@ -254,6 +252,9 @@ class HpackEncoderTest : public QuicheTestWithParam { HpackEncoder encoder_; test::HpackEncoderPeer peer_; + // Calculated based on the names and values inserted in SetUp(), above. + const size_t kInitialDynamicTableSize = 4 * (10 + 32); + const HpackEntry* static_; const HpackEntry* key_1_; const HpackEntry* key_2_; @@ -280,6 +281,7 @@ INSTANTIATE_TEST_SUITE_P(HpackEncoderTests, ::testing::Values(kDefault)); TEST_P(HpackEncoderTestWithDefaultStrategy, EncodeRepresentations) { + EXPECT_EQ(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); encoder_.SetHeaderListener( [this](absl::string_view name, absl::string_view value) { this->SaveHeaders(name, value); @@ -308,6 +310,30 @@ TEST_P(HpackEncoderTestWithDefaultStrategy, EncodeRepresentations) { Pair("accept", "text/html, text/plain,application/xml"), Pair("cookie", "val4"), Pair("withnul", absl::string_view("one\0two", 7)))); + // Insertions and evictions have happened over the course of the test. + EXPECT_GE(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); +} + +TEST_P(HpackEncoderTestWithDefaultStrategy, DynamicTableGrows) { + EXPECT_EQ(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); + peer_.table()->SetMaxSize(4096); + encoder_.SetHeaderListener( + [this](absl::string_view name, absl::string_view value) { + this->SaveHeaders(name, value); + }); + const std::vector> + header_list = {{"cookie", "val1; val2;val3"}, + {":path", "/home"}, + {"accept", "text/html, text/plain,application/xml"}, + {"cookie", "val4"}, + {"withnul", absl::string_view("one\0two", 7)}}; + std::string out; + EXPECT_TRUE(test::HpackEncoderPeer::EncodeRepresentations(&encoder_, + header_list, &out)); + + EXPECT_FALSE(out.empty()); + // Insertions have happened over the course of the test. + EXPECT_GT(encoder_.GetDynamicTableSize(), kInitialDynamicTableSize); } INSTANTIATE_TEST_SUITE_P(HpackEncoderTests, @@ -431,8 +457,8 @@ TEST_P(HpackEncoderTest, StringsDynamicallySelectHuffmanCoding) { expected_.AppendUint32(6); expected_.AppendBytes("@@@@@@"); - std::string expected_out, actual_out; - expected_.TakeString(&expected_out); + std::string actual_out; + std::string expected_out = expected_.TakeString(); peer_.TakeString(&actual_out); EXPECT_EQ(expected_out, actual_out); } @@ -479,6 +505,7 @@ TEST_P(HpackEncoderTest, EncodingWithoutCompression) { Pair("hello", "aloha"), Pair("multivalue", "value1, value2"))); } + EXPECT_EQ(kInitialDynamicTableSize, encoder_.GetDynamicTableSize()); } TEST_P(HpackEncoderTest, MultipleEncodingPasses) { diff --git a/gquiche/spdy/core/hpack/hpack_entry.cc b/gquiche/spdy/core/hpack/hpack_entry.cc index fb8747e1..a2b61637 100644 --- a/gquiche/spdy/core/hpack/hpack_entry.cc +++ b/gquiche/spdy/core/hpack/hpack_entry.cc @@ -5,7 +5,6 @@ #include "gquiche/spdy/core/hpack/hpack_entry.h" #include "absl/strings/str_cat.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" namespace spdy { @@ -24,8 +23,4 @@ std::string HpackEntry::GetDebugString() const { return absl::StrCat("{ name: \"", name_, "\", value: \"", value_, "\" }"); } -size_t HpackEntry::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(name_) + SpdyEstimateMemoryUsage(value_); -} - } // namespace spdy diff --git a/gquiche/spdy/core/hpack/hpack_entry.h b/gquiche/spdy/core/hpack/hpack_entry.h index 3c36e62c..6b956c80 100644 --- a/gquiche/spdy/core/hpack/hpack_entry.h +++ b/gquiche/spdy/core/hpack/hpack_entry.h @@ -71,9 +71,6 @@ class QUICHE_EXPORT_PRIVATE HpackEntry { std::string GetDebugString() const; - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - private: std::string name_; std::string value_; diff --git a/gquiche/spdy/core/hpack/hpack_header_table.cc b/gquiche/spdy/core/hpack/hpack_header_table.cc index 522739d1..e5d3143d 100644 --- a/gquiche/spdy/core/hpack/hpack_header_table.cc +++ b/gquiche/spdy/core/hpack/hpack_header_table.cc @@ -9,7 +9,6 @@ #include "gquiche/common/platform/api/quiche_logging.h" #include "gquiche/spdy/core/hpack/hpack_constants.h" #include "gquiche/spdy/core/hpack/hpack_static_table.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" namespace spdy { @@ -186,10 +185,4 @@ const HpackEntry* HpackHeaderTable::TryAddEntry(absl::string_view name, return &dynamic_entries_.front(); } -size_t HpackHeaderTable::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(dynamic_entries_) + - SpdyEstimateMemoryUsage(dynamic_index_) + - SpdyEstimateMemoryUsage(dynamic_name_index_); -} - } // namespace spdy diff --git a/gquiche/spdy/core/hpack/hpack_header_table.h b/gquiche/spdy/core/hpack/hpack_header_table.h index 645ba6ce..74d699b8 100644 --- a/gquiche/spdy/core/hpack/hpack_header_table.h +++ b/gquiche/spdy/core/hpack/hpack_header_table.h @@ -43,11 +43,6 @@ class QUICHE_EXPORT_PRIVATE HpackHeaderTable { // HpackHeaderTable takes advantage of the deque property that references // remain valid, so long as insertions & deletions are at the head & tail. // This precludes the use of base::circular_deque. - // - // If this changes (we want to change to circular_deque or we start to drop - // entries from the middle of the table), this should to be a std::list, in - // which case |*_index_| can be trivially extended to map to list iterators. - // // TODO(b/182349990): Change to a more memory efficient container. using DynamicEntryTable = std::deque; @@ -103,9 +98,6 @@ class QUICHE_EXPORT_PRIVATE HpackHeaderTable { const HpackEntry* TryAddEntry(absl::string_view name, absl::string_view value); - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - private: // Returns number of evictions required to enter |name| & |value|. size_t EvictionCountForEntry(absl::string_view name, diff --git a/gquiche/spdy/core/hpack/hpack_output_stream.cc b/gquiche/spdy/core/hpack/hpack_output_stream.cc index aac9244b..954ca209 100644 --- a/gquiche/spdy/core/hpack/hpack_output_stream.cc +++ b/gquiche/spdy/core/hpack/hpack_output_stream.cc @@ -7,7 +7,6 @@ #include #include "gquiche/common/platform/api/quiche_logging.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" namespace spdy { @@ -69,36 +68,33 @@ std::string* HpackOutputStream::MutableString() { return &buffer_; } -void HpackOutputStream::TakeString(std::string* output) { +std::string HpackOutputStream::TakeString() { // This must hold, since all public functions cause the buffer to // end on a byte boundary. QUICHE_DCHECK_EQ(bit_offset_, 0u); - buffer_.swap(*output); - buffer_.clear(); + std::string out = std::move(buffer_); + buffer_ = {}; bit_offset_ = 0; + return out; } -void HpackOutputStream::BoundedTakeString(size_t max_size, - std::string* output) { +std::string HpackOutputStream::BoundedTakeString(size_t max_size) { if (buffer_.size() > max_size) { // Save off overflow bytes to temporary string (causes a copy). - std::string overflow(buffer_.data() + max_size, buffer_.size() - max_size); + std::string overflow = buffer_.substr(max_size); // Resize buffer down to the given limit. buffer_.resize(max_size); // Give buffer to output string. - *output = std::move(buffer_); + std::string out = std::move(buffer_); // Reset to contain overflow. buffer_ = std::move(overflow); + return out; } else { - TakeString(output); + return TakeString(); } } -size_t HpackOutputStream::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(buffer_); -} - } // namespace spdy diff --git a/gquiche/spdy/core/hpack/hpack_output_stream.h b/gquiche/spdy/core/hpack/hpack_output_stream.h index 4afec362..02314d2b 100644 --- a/gquiche/spdy/core/hpack/hpack_output_stream.h +++ b/gquiche/spdy/core/hpack/hpack_output_stream.h @@ -51,19 +51,16 @@ class QUICHE_EXPORT_PRIVATE HpackOutputStream { // Return pointer to internal buffer. |bit_offset_| needs to be zero. std::string* MutableString(); - // Swaps the internal buffer with |output|, then resets state. - void TakeString(std::string* output); + // Returns the internal buffer as a string, then resets state. + std::string TakeString(); - // Gives up to |max_size| bytes of the internal buffer to |output|. Resets + // Returns up to |max_size| bytes of the internal buffer. Resets // internal state with the overflow. - void BoundedTakeString(size_t max_size, std::string* output); + std::string BoundedTakeString(size_t max_size); // Size in bytes of stream's internal buffer. size_t size() const { return buffer_.size(); } - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - private: // The internal bit buffer. std::string buffer_; diff --git a/gquiche/spdy/core/hpack/hpack_output_stream_test.cc b/gquiche/spdy/core/hpack/hpack_output_stream_test.cc index 3e61584f..99df02f3 100644 --- a/gquiche/spdy/core/hpack/hpack_output_stream_test.cc +++ b/gquiche/spdy/core/hpack/hpack_output_stream_test.cc @@ -37,8 +37,7 @@ TEST(HpackOutputStreamTest, AppendBits) { output_stream.AppendBits(0x0, 7); - std::string str; - output_stream.TakeString(&str); + std::string str = output_stream.TakeString(); EXPECT_EQ(expected_str, str); } @@ -50,8 +49,7 @@ std::string EncodeUint32(uint8_t N, uint32_t I) { output_stream.AppendBits(0x00, 8 - N); } output_stream.AppendUint32(I); - std::string str; - output_stream.TakeString(&str); + std::string str = output_stream.TakeString(); return str; } @@ -236,8 +234,7 @@ TEST(HpackOutputStreamTest, AppendUint32PreservesUpperBits) { HpackOutputStream output_stream; output_stream.AppendBits(0x7f, 7); output_stream.AppendUint32(0x01); - std::string str; - output_stream.TakeString(&str); + std::string str = output_stream.TakeString(); EXPECT_EQ(std::string("\xff\x00", 2), str); } @@ -247,8 +244,7 @@ TEST(HpackOutputStreamTest, AppendBytes) { output_stream.AppendBytes("buffer1"); output_stream.AppendBytes("buffer2"); - std::string str; - output_stream.TakeString(&str); + std::string str = output_stream.TakeString(); EXPECT_EQ("buffer1buffer2", str); } @@ -258,16 +254,15 @@ TEST(HpackOutputStreamTest, BoundedTakeString) { output_stream.AppendBytes("buffer12"); output_stream.AppendBytes("buffer456"); - std::string str; - output_stream.BoundedTakeString(9, &str); + std::string str = output_stream.BoundedTakeString(9); EXPECT_EQ("buffer12b", str); output_stream.AppendBits(0x7f, 7); output_stream.AppendUint32(0x11); - output_stream.BoundedTakeString(9, &str); + str = output_stream.BoundedTakeString(9); EXPECT_EQ("uffer456\xff", str); - output_stream.BoundedTakeString(9, &str); + str = output_stream.BoundedTakeString(9); EXPECT_EQ("\x10", str); } @@ -280,8 +275,7 @@ TEST(HpackOutputStreamTest, MutableString) { output_stream.AppendBytes("foo"); output_stream.MutableString()->append("bar"); - std::string str; - output_stream.TakeString(&str); + std::string str = output_stream.TakeString(); EXPECT_EQ("12foobar", str); } diff --git a/gquiche/spdy/core/hpack/hpack_round_trip_test.cc b/gquiche/spdy/core/hpack/hpack_round_trip_test.cc index 1de5e190..126be548 100644 --- a/gquiche/spdy/core/hpack/hpack_round_trip_test.cc +++ b/gquiche/spdy/core/hpack/hpack_round_trip_test.cc @@ -31,8 +31,7 @@ class HpackRoundTripTest : public QuicheTestWithParam { } bool RoundTrip(const SpdyHeaderBlock& header_set) { - std::string encoded; - encoder_.EncodeHeaderSet(header_set, &encoded); + std::string encoded = encoder_.EncodeHeaderBlock(header_set); bool success = true; if (GetParam() == ALL_INPUT) { @@ -58,7 +57,7 @@ class HpackRoundTripTest : public QuicheTestWithParam { } if (success) { - success = decoder_.HandleControlFrameHeadersComplete(nullptr); + success = decoder_.HandleControlFrameHeadersComplete(); } EXPECT_EQ(header_set, decoder_.decoded_block()); diff --git a/gquiche/spdy/core/hpack/hpack_static_table.cc b/gquiche/spdy/core/hpack/hpack_static_table.cc index ae69e90e..afc8a0b9 100644 --- a/gquiche/spdy/core/hpack/hpack_static_table.cc +++ b/gquiche/spdy/core/hpack/hpack_static_table.cc @@ -8,7 +8,6 @@ #include "gquiche/common/platform/api/quiche_logging.h" #include "gquiche/spdy/core/hpack/hpack_constants.h" #include "gquiche/spdy/core/hpack/hpack_entry.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" namespace spdy { @@ -48,10 +47,4 @@ bool HpackStaticTable::IsInitialized() const { return !static_entries_.empty(); } -size_t HpackStaticTable::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(static_entries_) + - SpdyEstimateMemoryUsage(static_index_) + - SpdyEstimateMemoryUsage(static_name_index_); -} - } // namespace spdy diff --git a/gquiche/spdy/core/hpack/hpack_static_table.h b/gquiche/spdy/core/hpack/hpack_static_table.h index b9530973..8fb35f69 100644 --- a/gquiche/spdy/core/hpack/hpack_static_table.h +++ b/gquiche/spdy/core/hpack/hpack_static_table.h @@ -43,9 +43,6 @@ class QUICHE_EXPORT_PRIVATE HpackStaticTable { return static_name_index_; } - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - private: HpackHeaderTable::StaticEntryTable static_entries_; // The following two members have string_views that point to strings stored in diff --git a/gquiche/spdy/core/http2_frame_decoder_adapter.cc b/gquiche/spdy/core/http2_frame_decoder_adapter.cc index 21876176..0c9b4699 100644 --- a/gquiche/spdy/core/http2_frame_decoder_adapter.cc +++ b/gquiche/spdy/core/http2_frame_decoder_adapter.cc @@ -30,8 +30,6 @@ #include "gquiche/spdy/core/spdy_header_block.h" #include "gquiche/spdy/core/spdy_headers_handler_interface.h" #include "gquiche/spdy/core/spdy_protocol.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" -#include "gquiche/spdy/platform/api/spdy_string_utils.h" using ::spdy::ExtensionVisitorInterface; using ::spdy::HpackDecoderAdapter; @@ -40,7 +38,6 @@ using ::spdy::ParseErrorCode; using ::spdy::ParseFrameType; using ::spdy::SpdyAltSvcWireFormat; using ::spdy::SpdyErrorCode; -using ::spdy::SpdyEstimateMemoryUsage; using ::spdy::SpdyFramerDebugVisitorInterface; using ::spdy::SpdyFramerVisitorInterface; using ::spdy::SpdyFrameType; @@ -191,24 +188,12 @@ const char* Http2DecoderAdapter::SpdyFramerErrorToString( return "INVALID_CONTROL_FRAME"; case SPDY_CONTROL_PAYLOAD_TOO_LARGE: return "CONTROL_PAYLOAD_TOO_LARGE"; - case SPDY_ZLIB_INIT_FAILURE: - return "ZLIB_INIT_FAILURE"; - case SPDY_UNSUPPORTED_VERSION: - return "UNSUPPORTED_VERSION"; case SPDY_DECOMPRESS_FAILURE: return "DECOMPRESS_FAILURE"; - case SPDY_COMPRESS_FAILURE: - return "COMPRESS_FAILURE"; - case SPDY_GOAWAY_FRAME_CORRUPT: - return "GOAWAY_FRAME_CORRUPT"; - case SPDY_RST_STREAM_FRAME_CORRUPT: - return "RST_STREAM_FRAME_CORRUPT"; case SPDY_INVALID_PADDING: return "INVALID_PADDING"; case SPDY_INVALID_DATA_FRAME_FLAGS: return "INVALID_DATA_FRAME_FLAGS"; - case SPDY_INVALID_CONTROL_FRAME_FLAGS: - return "INVALID_CONTROL_FRAME_FLAGS"; case SPDY_UNEXPECTED_FRAME: return "UNEXPECTED_FRAME"; case SPDY_INTERNAL_FRAMER_ERROR: @@ -249,6 +234,8 @@ const char* Http2DecoderAdapter::SpdyFramerErrorToString( return "HPACK_FRAGMENT_TOO_LONG"; case SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT: return "HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT"; + case SPDY_STOP_PROCESSING: + return "STOP_PROCESSING"; case LAST_ERROR: return "UNKNOWN_ERROR"; } @@ -320,11 +307,9 @@ bool Http2DecoderAdapter::probable_http_response() const { return latched_probable_http_response_; } -size_t Http2DecoderAdapter::EstimateMemoryUsage() const { - // Skip |frame_decoder_|, |frame_header_| and |hpack_first_frame_header_| as - // they don't allocate. - return SpdyEstimateMemoryUsage(alt_svc_origin_) + - SpdyEstimateMemoryUsage(alt_svc_value_); +void Http2DecoderAdapter::StopProcessing() { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_STOP_PROCESSING, + "Ignoring further events on this connection."); } // =========================================================================== @@ -784,12 +769,12 @@ void Http2DecoderAdapter::OnFrameSizeError(const Http2FrameHeader& header) { QUICHE_DVLOG(1) << "OnFrameSizeError: " << header; size_t recv_limit = recv_frame_size_limit_; if (header.payload_length > recv_limit) { - SetSpdyErrorAndNotify(SpdyFramerError::SPDY_OVERSIZED_PAYLOAD, ""); - return; - } - if (header.type != Http2FrameType::DATA && - header.payload_length > recv_limit) { - SetSpdyErrorAndNotify(SpdyFramerError::SPDY_CONTROL_PAYLOAD_TOO_LARGE, ""); + if (header.type == Http2FrameType::DATA) { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_OVERSIZED_PAYLOAD, ""); + } else { + SetSpdyErrorAndNotify(SpdyFramerError::SPDY_CONTROL_PAYLOAD_TOO_LARGE, + ""); + } return; } switch (header.type) { @@ -1107,7 +1092,7 @@ void Http2DecoderAdapter::CommonHpackFragmentEnd() { << frame_header(); has_expected_frame_type_ = false; auto* decoder = GetHpackDecoder(); - if (decoder->HandleControlFrameHeadersComplete(nullptr)) { + if (decoder->HandleControlFrameHeadersComplete()) { visitor()->OnHeaderFrameEnd(stream_id()); } else { SetSpdyErrorAndNotify( diff --git a/gquiche/spdy/core/http2_frame_decoder_adapter.h b/gquiche/spdy/core/http2_frame_decoder_adapter.h index 85e21748..f8feb21a 100644 --- a/gquiche/spdy/core/http2_frame_decoder_adapter.h +++ b/gquiche/spdy/core/http2_frame_decoder_adapter.h @@ -62,15 +62,9 @@ class QUICHE_EXPORT_PRIVATE Http2DecoderAdapter SPDY_INVALID_STREAM_ID, // Stream ID is invalid SPDY_INVALID_CONTROL_FRAME, // Control frame is mal-formatted. SPDY_CONTROL_PAYLOAD_TOO_LARGE, // Control frame payload was too large. - SPDY_ZLIB_INIT_FAILURE, // The Zlib library could not initialize. - SPDY_UNSUPPORTED_VERSION, // Control frame has unsupported version. SPDY_DECOMPRESS_FAILURE, // There was an error decompressing. - SPDY_COMPRESS_FAILURE, // There was an error compressing. - SPDY_GOAWAY_FRAME_CORRUPT, // GOAWAY frame could not be parsed. - SPDY_RST_STREAM_FRAME_CORRUPT, // RST_STREAM frame could not be parsed. SPDY_INVALID_PADDING, // HEADERS or DATA frame padding invalid SPDY_INVALID_DATA_FRAME_FLAGS, // Data frame has invalid flags. - SPDY_INVALID_CONTROL_FRAME_FLAGS, // Control frame has invalid flags. SPDY_UNEXPECTED_FRAME, // Frame received out of order. SPDY_INTERNAL_FRAMER_ERROR, // SpdyFramer was used incorrectly. SPDY_INVALID_CONTROL_FRAME_SIZE, // Control frame not sized to spec @@ -95,6 +89,10 @@ class QUICHE_EXPORT_PRIVATE Http2DecoderAdapter SPDY_HPACK_FRAGMENT_TOO_LONG, SPDY_HPACK_COMPRESSED_HEADER_SIZE_EXCEEDS_LIMIT, + // Set if the visitor no longer wishes to receive events for this + // connection. + SPDY_STOP_PROCESSING, + LAST_ERROR, // Must be the last entry in the enum. }; @@ -153,13 +151,17 @@ class QUICHE_EXPORT_PRIVATE Http2DecoderAdapter // has responded with an HTTP/1.1 (or earlier) response. bool probable_http_response() const; - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - spdy::HpackDecoderAdapter* GetHpackDecoder(); + const spdy::HpackDecoderAdapter* GetHpackDecoder() const { + return hpack_decoder_.get(); + } bool HasError() const; + // A visitor may call this method to indicate it no longer wishes to receive + // events for this connection. + void StopProcessing(); + private: bool OnFrameHeader(const Http2FrameHeader& header) override; void OnDataStart(const Http2FrameHeader& header) override; diff --git a/gquiche/spdy/core/http2_header_block_hpack_listener.h b/gquiche/spdy/core/http2_header_block_hpack_listener.h new file mode 100644 index 00000000..47b6825f --- /dev/null +++ b/gquiche/spdy/core/http2_header_block_hpack_listener.h @@ -0,0 +1,47 @@ +#ifndef QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_HPACK_LISTENER_H_ +#define QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_HPACK_LISTENER_H_ + +#include "absl/strings/string_view.h" +#include "gquiche/http2/hpack/decoder/hpack_decoder_listener.h" +#include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/spdy/core/spdy_header_block.h" + +namespace spdy { + +// This class simply gathers the key-value pairs emitted by an HpackDecoder in +// a SpdyHeaderBlock. +class Http2HeaderBlockHpackListener : public http2::HpackDecoderListener { + public: + Http2HeaderBlockHpackListener() {} + + void OnHeaderListStart() override { + header_block_.clear(); + hpack_error_ = false; + } + + void OnHeader(const std::string& name, const std::string& value) override { + header_block_.AppendValueOrAddHeader(name, value); + } + + void OnHeaderListEnd() override {} + + void OnHeaderErrorDetected(absl::string_view error_message) override { + QUICHE_VLOG(1) << error_message; + hpack_error_ = true; + } + + SpdyHeaderBlock release_header_block() { + SpdyHeaderBlock block = std::move(header_block_); + header_block_ = {}; + return block; + } + bool hpack_error() const { return hpack_error_; } + + private: + SpdyHeaderBlock header_block_; + bool hpack_error_ = false; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_HTTP2_HEADER_BLOCK_HPACK_LISTENER_H_ diff --git a/gquiche/spdy/core/metadata_extension.cc b/gquiche/spdy/core/metadata_extension.cc new file mode 100644 index 00000000..9559edcf --- /dev/null +++ b/gquiche/spdy/core/metadata_extension.cc @@ -0,0 +1,195 @@ +#include "gquiche/spdy/core/metadata_extension.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "gquiche/http2/decoder/decode_buffer.h" +#include "gquiche/http2/hpack/decoder/hpack_decoder.h" +#include "gquiche/common/platform/api/quiche_bug_tracker.h" +#include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/spdy/core/hpack/hpack_encoder.h" +#include "gquiche/spdy/core/http2_header_block_hpack_listener.h" + +namespace spdy { + +// Non-standard constants related to METADATA frames. +const SpdySettingsId MetadataVisitor::kMetadataExtensionId = 0x4d44; +const uint8_t MetadataVisitor::kMetadataFrameType = 0x4d; +const uint8_t MetadataVisitor::kEndMetadataFlag = 0x4; + +namespace { + +const size_t kMaxMetadataBlockSize = 1 << 20; // 1 MB + +// This class uses an HpackEncoder to serialize a METADATA block as a series of +// METADATA frames. +class MetadataFrameSequence : public MetadataSerializer::FrameSequence { + public: + MetadataFrameSequence(SpdyStreamId stream_id, spdy::SpdyHeaderBlock payload) + : stream_id_(stream_id), payload_(std::move(payload)) { + // Metadata should not use HPACK compression. + encoder_.DisableCompression(); + HpackEncoder::Representations r; + for (const auto& kv_pair : payload_) { + r.push_back(kv_pair); + } + progressive_encoder_ = encoder_.EncodeRepresentations(r); + } + + // Copies are not allowed. + MetadataFrameSequence(const MetadataFrameSequence& other) = delete; + MetadataFrameSequence& operator=(const MetadataFrameSequence& other) = delete; + + std::unique_ptr Next() override; + + private: + SpdyStreamId stream_id_; + SpdyHeaderBlock payload_; + HpackEncoder encoder_; + std::unique_ptr progressive_encoder_; +}; + +std::unique_ptr MetadataFrameSequence::Next() { + if (!progressive_encoder_->HasNext()) { + return nullptr; + } + // METADATA frames obey the HTTP/2 maximum frame size. + std::string payload = + progressive_encoder_->Next(spdy::kHttp2DefaultFramePayloadLimit); + const bool end_metadata = (!progressive_encoder_->HasNext()); + const uint8_t flags = end_metadata ? MetadataVisitor::kEndMetadataFlag : 0; + return absl::make_unique( + stream_id_, MetadataVisitor::kMetadataFrameType, flags, + std::move(payload)); +} + +} // anonymous namespace + +struct MetadataVisitor::MetadataPayloadState { + MetadataPayloadState(size_t remaining, bool end) + : bytes_remaining(remaining), end_metadata(end) {} + std::list buffer; + size_t bytes_remaining; + bool end_metadata; +}; + +MetadataVisitor::MetadataVisitor(OnCompletePayload on_payload, + OnMetadataSupport on_support) + : on_payload_(std::move(on_payload)), + on_support_(std::move(on_support)), + peer_supports_metadata_(MetadataSupportState::UNSPECIFIED) {} + +MetadataVisitor::~MetadataVisitor() {} + +void MetadataVisitor::OnSetting(SpdySettingsId id, uint32_t value) { + QUICHE_VLOG(1) << "MetadataVisitor::OnSetting(" << id << ", " << value << ")"; + if (id == kMetadataExtensionId) { + if (value == 0) { + const MetadataSupportState previous_state = peer_supports_metadata_; + peer_supports_metadata_ = MetadataSupportState::NOT_SUPPORTED; + if (previous_state == MetadataSupportState::UNSPECIFIED || + previous_state == MetadataSupportState::SUPPORTED) { + on_support_(false); + } + } else if (value == 1) { + const MetadataSupportState previous_state = peer_supports_metadata_; + peer_supports_metadata_ = MetadataSupportState::SUPPORTED; + if (previous_state == MetadataSupportState::UNSPECIFIED || + previous_state == MetadataSupportState::NOT_SUPPORTED) { + on_support_(true); + } + } else { + LOG_EVERY_N_SEC(WARNING, 1) + << "Unrecognized value for setting " << id << ": " << value; + } + } +} + +bool MetadataVisitor::OnFrameHeader(SpdyStreamId stream_id, size_t length, + uint8_t type, uint8_t flags) { + QUICHE_VLOG(1) << "OnFrameHeader(stream_id=" << stream_id + << ", length=" << length << ", type=" << static_cast(type) + << ", flags=" << static_cast(flags); + // TODO(birenroy): Consider disabling METADATA handling until our setting + // advertising METADATA support has been acked. + if (type != kMetadataFrameType) { + return false; + } + auto it = metadata_map_.find(stream_id); + if (it == metadata_map_.end()) { + auto state = absl::make_unique( + length, flags & kEndMetadataFlag); + auto result = metadata_map_.insert(std::make_pair(stream_id, + std::move(state))); + QUICHE_BUG_IF(bug_if_2781_1, !result.second) << "Map insertion failed."; + it = result.first; + } else { + QUICHE_BUG_IF(bug_22051_1, it->second->end_metadata) + << "Inconsistent metadata payload state!"; + QUICHE_BUG_IF(bug_if_2781_2, it->second->bytes_remaining > 0) + << "Incomplete metadata block!"; + } + + if (it->second == nullptr) { + QUICHE_BUG(bug_2781_3) << "Null metadata payload state!"; + return false; + } + current_stream_ = stream_id; + it->second->bytes_remaining = length; + it->second->end_metadata = (flags & kEndMetadataFlag); + return true; +} + +void MetadataVisitor::OnFramePayload(const char* data, size_t len) { + QUICHE_VLOG(1) << "OnFramePayload(stream_id=" << current_stream_ + << ", len=" << len << ")"; + auto it = metadata_map_.find(current_stream_); + if (it == metadata_map_.end() || it->second == nullptr) { + QUICHE_BUG(bug_2781_4) << "Invalid order of operations on MetadataVisitor."; + } else { + MetadataPayloadState* state = it->second.get(); // For readability. + state->buffer.push_back(std::string(data, len)); + if (len < state->bytes_remaining) { + state->bytes_remaining -= len; + } else { + QUICHE_BUG_IF(bug_22051_2, len > state->bytes_remaining) + << "Metadata payload overflow! len: " << len + << " bytes_remaining: " << state->bytes_remaining; + state->bytes_remaining = 0; + if (state->end_metadata) { + // The whole process of decoding the HPACK-encoded metadata block, + // below, is more cumbersome than it ought to be. + spdy::Http2HeaderBlockHpackListener listener; + http2::HpackDecoder decoder(&listener, kMaxMetadataBlockSize); + + // If any operations fail, the decode process should be aborted. + bool success = decoder.StartDecodingBlock(); + for (const std::string& slice : state->buffer) { + if (!success) { + break; + } + http2::DecodeBuffer buffer(slice.data(), slice.size()); + success = success && decoder.DecodeFragment(&buffer); + } + success = + success && decoder.EndDecodingBlock() && !listener.hpack_error(); + if (success) { + on_payload_(current_stream_, listener.release_header_block()); + } + // TODO(birenroy): add varz counting metadata decode successes/failures. + metadata_map_.erase(it); + } + } + } +} + +std::unique_ptr +MetadataSerializer::FrameSequenceForPayload(SpdyStreamId stream_id, + MetadataPayload payload) { + return absl::make_unique(stream_id, + std::move(payload)); +} + +} // namespace spdy diff --git a/gquiche/spdy/core/metadata_extension.h b/gquiche/spdy/core/metadata_extension.h new file mode 100644 index 00000000..453673e7 --- /dev/null +++ b/gquiche/spdy/core/metadata_extension.h @@ -0,0 +1,116 @@ +#ifndef QUICHE_SPDY_CORE_METADATA_EXTENSION_H_ +#define QUICHE_SPDY_CORE_METADATA_EXTENSION_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "gquiche/spdy/core/http2_frame_decoder_adapter.h" +#include "gquiche/spdy/core/spdy_header_block.h" +#include "gquiche/spdy/core/spdy_protocol.h" +#include "gquiche/spdy/core/zero_copy_output_buffer.h" + +namespace spdy { + +// An implementation of the ExtensionVisitorInterface that can parse +// METADATA frames. METADATA is a non-standard HTTP/2 extension developed and +// used internally at Google. A peer advertises support for METADATA by sending +// a setting with a setting ID of kMetadataExtensionId and a value of 1. +// +// Metadata is represented as a HPACK header block with literal encoding. +class MetadataVisitor : public spdy::ExtensionVisitorInterface { + public: + using MetadataPayload = spdy::SpdyHeaderBlock; + + static_assert(!std::is_copy_constructible::value, + "MetadataPayload should be a move-only type!"); + + using OnMetadataSupport = std::function; + using OnCompletePayload = + std::function; + + // The HTTP/2 SETTINGS ID that is used to indicate support for METADATA + // frames. + static const spdy::SpdySettingsId kMetadataExtensionId; + + // The 8-bit frame type code for a METADATA frame. + static const uint8_t kMetadataFrameType; + + // The flag that indicates the end of a logical metadata block. Due to frame + // size limits, a single metadata block may be emitted as several HTTP/2 + // frames. + static const uint8_t kEndMetadataFlag; + + // |on_payload| is invoked whenever a complete metadata payload is received. + // |on_support| is invoked whenever the peer's advertised support for metadata + // changes. + MetadataVisitor(OnCompletePayload on_payload, OnMetadataSupport on_support); + ~MetadataVisitor() override; + + MetadataVisitor(const MetadataVisitor&) = delete; + MetadataVisitor& operator=(const MetadataVisitor&) = delete; + + // Interprets the non-standard setting indicating support for METADATA. + void OnSetting(spdy::SpdySettingsId id, uint32_t value) override; + + // Returns true iff |type| indicates a METADATA frame. + bool OnFrameHeader(spdy::SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags) override; + + // Consumes a METADATA frame payload. Invokes the registered callback when a + // complete payload has been received. + void OnFramePayload(const char* data, size_t len) override; + + // Returns true if the peer has advertised support for METADATA via the + // appropriate setting. + bool PeerSupportsMetadata() const { + return peer_supports_metadata_ == MetadataSupportState::SUPPORTED; + } + + private: + enum class MetadataSupportState : uint8_t { + UNSPECIFIED, + SUPPORTED, + NOT_SUPPORTED, + }; + + struct MetadataPayloadState; + + using StreamMetadataMap = + absl::flat_hash_map>; + + OnCompletePayload on_payload_; + OnMetadataSupport on_support_; + StreamMetadataMap metadata_map_; + spdy::SpdyStreamId current_stream_; + MetadataSupportState peer_supports_metadata_; +}; + +// A class that serializes metadata blocks as sequences of frames. +class MetadataSerializer { + public: + using MetadataPayload = spdy::SpdyHeaderBlock; + + class FrameSequence { + public: + virtual ~FrameSequence() {} + + // Returns nullptr once the sequence has been exhausted. + virtual std::unique_ptr Next() = 0; + }; + + MetadataSerializer() {} + + MetadataSerializer(const MetadataSerializer&) = delete; + MetadataSerializer& operator=(const MetadataSerializer&) = delete; + + // Returns nullptr on failure. + std::unique_ptr FrameSequenceForPayload( + spdy::SpdyStreamId stream_id, MetadataPayload payload); +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_METADATA_EXTENSION_H_ diff --git a/gquiche/spdy/core/metadata_extension_test.cc b/gquiche/spdy/core/metadata_extension_test.cc new file mode 100644 index 00000000..d068c8d3 --- /dev/null +++ b/gquiche/spdy/core/metadata_extension_test.cc @@ -0,0 +1,229 @@ +#include "gquiche/spdy/core/metadata_extension.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/bind_front.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/spdy/core/array_output_buffer.h" +#include "gquiche/spdy/core/mock_spdy_framer_visitor.h" +#include "gquiche/spdy/core/spdy_framer.h" +#include "gquiche/spdy/core/spdy_header_block.h" +#include "gquiche/spdy/core/spdy_no_op_visitor.h" +#include "gquiche/spdy/core/spdy_protocol.h" + +namespace spdy { +namespace test { +namespace { + +using ::absl::bind_front; +using ::spdy::SpdyFramer; +using ::spdy::SpdyHeaderBlock; +using ::spdy::test::MockSpdyFramerVisitor; +using ::testing::_; +using ::testing::ElementsAre; +using ::testing::IsEmpty; + +const size_t kBufferSize = 64 * 1024; +char kBuffer[kBufferSize]; + +class MetadataExtensionTest : public QuicheTest { + protected: + MetadataExtensionTest() : test_buffer_(kBuffer, kBufferSize) {} + + void SetUp() override { + extension_ = absl::make_unique( + bind_front(&MetadataExtensionTest::OnCompletePayload, this), + bind_front(&MetadataExtensionTest::OnMetadataSupport, this)); + } + + void OnCompletePayload(spdy::SpdyStreamId stream_id, + MetadataVisitor::MetadataPayload payload) { + ++received_count_; + received_payload_map_.insert(std::make_pair(stream_id, std::move(payload))); + } + + void OnMetadataSupport(bool peer_supports_metadata) { + EXPECT_EQ(peer_supports_metadata, extension_->PeerSupportsMetadata()); + received_metadata_support_.push_back(peer_supports_metadata); + } + + MetadataSerializer::MetadataPayload PayloadForData(absl::string_view data) { + SpdyHeaderBlock block; + block["example-payload"] = data; + return block; + } + + std::unique_ptr extension_; + absl::flat_hash_map + received_payload_map_; + std::vector received_metadata_support_; + size_t received_count_ = 0; + spdy::ArrayOutputBuffer test_buffer_; +}; + +// This test verifies that the MetadataVisitor is initialized to a state where +// it believes the peer does not support metadata. +TEST_F(MetadataExtensionTest, MetadataNotSupported) { + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + EXPECT_THAT(received_metadata_support_, IsEmpty()); +} + +// This test verifies that upon receiving a specific setting, the extension +// realizes that the peer supports metadata. +TEST_F(MetadataExtensionTest, MetadataSupported) { + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + // 3 is not an appropriate value for the metadata extension key. + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 3); + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 0); + EXPECT_FALSE(extension_->PeerSupportsMetadata()); + EXPECT_THAT(received_metadata_support_, ElementsAre(true, false)); +} + +TEST_F(MetadataExtensionTest, MetadataIgnoredWithoutExtension) { + const char kData[] = "some payload"; + SpdyHeaderBlock payload = PayloadForData(kData); + + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + + MetadataSerializer serializer; + auto sequence = serializer.FrameSequenceForPayload(3, std::move(payload)); + ASSERT_TRUE(sequence != nullptr); + + http2::Http2DecoderAdapter deframer; + ::testing::StrictMock visitor; + deframer.set_visitor(&visitor); + + EXPECT_CALL(visitor, + OnCommonHeader(3, _, MetadataVisitor::kMetadataFrameType, _)); + // The Return(true) should not be necessary. http://b/36023792 + EXPECT_CALL(visitor, OnUnknownFrame(3, MetadataVisitor::kMetadataFrameType)) + .WillOnce(::testing::Return(true)); + + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + auto frame = sequence->Next(); + ASSERT_TRUE(frame != nullptr); + while (frame != nullptr) { + const size_t frame_size = framer.SerializeFrame(*frame, &test_buffer_); + ASSERT_GT(frame_size, 0u); + ASSERT_FALSE(deframer.HasError()); + ASSERT_EQ(frame_size, test_buffer_.Size()); + EXPECT_EQ(frame_size, deframer.ProcessInput(kBuffer, frame_size)); + test_buffer_.Reset(); + frame = sequence->Next(); + } + EXPECT_FALSE(deframer.HasError()); + EXPECT_THAT(received_metadata_support_, ElementsAre(true)); +} + +// This test verifies that the METADATA frame emitted by a MetadataExtension +// can be parsed by another SpdyFramer with a MetadataVisitor. +TEST_F(MetadataExtensionTest, MetadataPayloadEndToEnd) { + SpdyHeaderBlock block1; + block1["foo"] = "Some metadata value."; + SpdyHeaderBlock block2; + block2["bar"] = + "The color taupe truly represents a triumph of the human spirit over " + "adversity."; + block2["baz"] = + "Or perhaps it represents abject surrender to the implacable and " + "incomprehensible forces of the universe."; + const absl::string_view binary_payload{"binary\0payload", 14}; + block2["qux"] = binary_payload; + EXPECT_EQ(binary_payload, block2["qux"]); + for (const SpdyHeaderBlock& payload_block : + {std::move(block1), std::move(block2)}) { + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + + MetadataSerializer serializer; + auto sequence = + serializer.FrameSequenceForPayload(3, payload_block.Clone()); + ASSERT_TRUE(sequence != nullptr); + + http2::Http2DecoderAdapter deframer; + ::spdy::SpdyNoOpVisitor visitor; + deframer.set_visitor(&visitor); + deframer.set_extension_visitor(extension_.get()); + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + auto frame = sequence->Next(); + ASSERT_TRUE(frame != nullptr); + while (frame != nullptr) { + const size_t frame_size = framer.SerializeFrame(*frame, &test_buffer_); + ASSERT_GT(frame_size, 0u); + ASSERT_FALSE(deframer.HasError()); + ASSERT_EQ(frame_size, test_buffer_.Size()); + EXPECT_EQ(frame_size, deframer.ProcessInput(kBuffer, frame_size)); + test_buffer_.Reset(); + frame = sequence->Next(); + } + EXPECT_EQ(1u, received_count_); + auto it = received_payload_map_.find(3); + ASSERT_TRUE(it != received_payload_map_.end()); + EXPECT_EQ(payload_block, it->second); + + received_count_ = 0; + received_payload_map_.clear(); + } +} + +// This test verifies that METADATA frames for two different streams can be +// interleaved and still successfully parsed by another SpdyFramer with a +// MetadataVisitor. +TEST_F(MetadataExtensionTest, MetadataPayloadInterleaved) { + const std::string kData1 = std::string(65 * 1024, 'a'); + const std::string kData2 = std::string(65 * 1024, 'b'); + const SpdyHeaderBlock payload1 = PayloadForData(kData1); + const SpdyHeaderBlock payload2 = PayloadForData(kData2); + + extension_->OnSetting(MetadataVisitor::kMetadataExtensionId, 1); + ASSERT_TRUE(extension_->PeerSupportsMetadata()); + + MetadataSerializer serializer; + auto sequence1 = serializer.FrameSequenceForPayload(3, payload1.Clone()); + ASSERT_TRUE(sequence1 != nullptr); + + auto sequence2 = serializer.FrameSequenceForPayload(5, payload2.Clone()); + ASSERT_TRUE(sequence2 != nullptr); + + http2::Http2DecoderAdapter deframer; + ::spdy::SpdyNoOpVisitor visitor; + deframer.set_visitor(&visitor); + deframer.set_extension_visitor(extension_.get()); + + SpdyFramer framer(SpdyFramer::ENABLE_COMPRESSION); + auto frame1 = sequence1->Next(); + ASSERT_TRUE(frame1 != nullptr); + auto frame2 = sequence2->Next(); + ASSERT_TRUE(frame2 != nullptr); + while (frame1 != nullptr || frame2 != nullptr) { + for (auto frame : {frame1.get(), frame2.get()}) { + if (frame != nullptr) { + const size_t frame_size = framer.SerializeFrame(*frame, &test_buffer_); + ASSERT_GT(frame_size, 0u); + ASSERT_FALSE(deframer.HasError()); + ASSERT_EQ(frame_size, test_buffer_.Size()); + EXPECT_EQ(frame_size, deframer.ProcessInput(kBuffer, frame_size)); + test_buffer_.Reset(); + } + } + frame1 = sequence1->Next(); + frame2 = sequence2->Next(); + } + EXPECT_EQ(2u, received_count_); + auto it = received_payload_map_.find(3); + ASSERT_TRUE(it != received_payload_map_.end()); + EXPECT_EQ(payload1, it->second); + + it = received_payload_map_.find(5); + ASSERT_TRUE(it != received_payload_map_.end()); + EXPECT_EQ(payload2, it->second); +} + +} // anonymous namespace +} // namespace test +} // namespace spdy diff --git a/gquiche/spdy/core/mock_spdy_framer_visitor.h b/gquiche/spdy/core/mock_spdy_framer_visitor.h index 2b0741d9..85e82467 100644 --- a/gquiche/spdy/core/mock_spdy_framer_visitor.h +++ b/gquiche/spdy/core/mock_spdy_framer_visitor.h @@ -19,7 +19,8 @@ namespace spdy { namespace test { -class MockSpdyFramerVisitor : public SpdyFramerVisitorInterface { +class QUICHE_NO_EXPORT MockSpdyFramerVisitor + : public SpdyFramerVisitorInterface { public: MockSpdyFramerVisitor(); ~MockSpdyFramerVisitor() override; @@ -29,6 +30,10 @@ class MockSpdyFramerVisitor : public SpdyFramerVisitorInterface { (http2::Http2DecoderAdapter::SpdyFramerError error, std::string detailed_error), (override)); + MOCK_METHOD(void, OnCommonHeader, + (SpdyStreamId stream_id, size_t length, uint8_t type, + uint8_t flags), + (override)); MOCK_METHOD(void, OnDataFrameHeader, (SpdyStreamId stream_id, size_t length, bool fin), @@ -59,10 +64,13 @@ class MockSpdyFramerVisitor : public SpdyFramerVisitorInterface { MOCK_METHOD(void, OnSetting, (SpdySettingsId id, uint32_t value), (override)); MOCK_METHOD(void, OnPing, (SpdyPingId unique_id, bool is_ack), (override)); MOCK_METHOD(void, OnSettingsEnd, (), (override)); + MOCK_METHOD(void, OnSettingsAck, (), (override)); MOCK_METHOD(void, OnGoAway, (SpdyStreamId last_accepted_stream_id, SpdyErrorCode error_code), (override)); + MOCK_METHOD(bool, OnGoAwayFrameData, (const char* goaway_data, size_t len), + (override)); MOCK_METHOD(void, OnHeaders, (SpdyStreamId stream_id, diff --git a/gquiche/spdy/core/no_op_headers_handler.h b/gquiche/spdy/core/no_op_headers_handler.h new file mode 100644 index 00000000..c5853de2 --- /dev/null +++ b/gquiche/spdy/core/no_op_headers_handler.h @@ -0,0 +1,39 @@ +#ifndef QUICHE_SPDY_CORE_NO_OP_HEADERS_HANDLER_H_ +#define QUICHE_SPDY_CORE_NO_OP_HEADERS_HANDLER_H_ + +#include "gquiche/common/platform/api/quiche_export.h" +#include "gquiche/spdy/core/header_byte_listener_interface.h" +#include "gquiche/spdy/core/spdy_headers_handler_interface.h" + +namespace spdy { + +// Drops all header data, but passes information about header bytes parsed to +// a listener. +class QUICHE_EXPORT_PRIVATE NoOpHeadersHandler + : public SpdyHeadersHandlerInterface { + public: + // Does not take ownership of listener. + explicit NoOpHeadersHandler(HeaderByteListenerInterface* listener) + : listener_(listener) {} + NoOpHeadersHandler(const NoOpHeadersHandler&) = delete; + NoOpHeadersHandler& operator=(const NoOpHeadersHandler&) = delete; + ~NoOpHeadersHandler() override {} + + // From SpdyHeadersHandlerInterface + void OnHeaderBlockStart() override {} + void OnHeader(absl::string_view /*key*/, + absl::string_view /*value*/) override {} + void OnHeaderBlockEnd(size_t uncompressed_header_bytes, + size_t /* compressed_header_bytes */) override { + if (listener_ != nullptr) { + listener_->OnHeaderBytesReceived(uncompressed_header_bytes); + } + } + + private: + HeaderByteListenerInterface* listener_; +}; + +} // namespace spdy + +#endif // QUICHE_SPDY_CORE_NO_OP_HEADERS_HANDLER_H_ diff --git a/gquiche/spdy/core/spdy_alt_svc_wire_format.cc b/gquiche/spdy/core/spdy_alt_svc_wire_format.cc index 1a3cd762..f933563e 100644 --- a/gquiche/spdy/core/spdy_alt_svc_wire_format.cc +++ b/gquiche/spdy/core/spdy_alt_svc_wire_format.cc @@ -11,7 +11,6 @@ #include "absl/strings/str_cat.h" #include "gquiche/common/platform/api/quiche_logging.h" -#include "gquiche/spdy/platform/api/spdy_string_utils.h" namespace spdy { @@ -40,15 +39,12 @@ bool ParsePositiveIntegerImpl(absl::string_view::const_iterator c, SpdyAltSvcWireFormat::AlternativeService::AlternativeService() = default; SpdyAltSvcWireFormat::AlternativeService::AlternativeService( - const std::string& protocol_id, - const std::string& host, - uint16_t port, - uint32_t max_age, - VersionVector version) + const std::string& protocol_id, const std::string& host, uint16_t port, + uint32_t max_age_seconds, VersionVector version) : protocol_id(protocol_id), host(host), port(port), - max_age(max_age), + max_age_seconds(max_age_seconds), version(std::move(version)) {} SpdyAltSvcWireFormat::AlternativeService::~AlternativeService() = default; @@ -114,7 +110,7 @@ bool SpdyAltSvcWireFormat::ParseHeaderFieldValue( } ++c; // Parse parameters. - uint32_t max_age = 86400; + uint32_t max_age_seconds = 86400; VersionVector version; absl::string_view::const_iterator parameters_end = std::find(c, value.end(), ','); @@ -148,7 +144,8 @@ bool SpdyAltSvcWireFormat::ParseHeaderFieldValue( return false; } if (parameter_name == "ma") { - if (!ParsePositiveInteger32(parameter_value_begin, c, &max_age)) { + if (!ParsePositiveInteger32(parameter_value_begin, c, + &max_age_seconds)) { return false; } } else if (!is_ietf_format_quic && parameter_name == "v") { @@ -196,16 +193,17 @@ bool SpdyAltSvcWireFormat::ParseHeaderFieldValue( // hq=":443";quic=51303338 // ... will be stored in |versions| as 0x51303338. uint32_t quic_version; - if (!SpdyHexDecodeToUInt32(absl::string_view(&*parameter_value_begin, - c - parameter_value_begin), - &quic_version) || + if (!HexDecodeToUInt32(absl::string_view(&*parameter_value_begin, + c - parameter_value_begin), + &quic_version) || quic_version == 0) { return false; } version.push_back(quic_version); } } - altsvc_vector->emplace_back(protocol_id, host, port, max_age, version); + altsvc_vector->emplace_back(protocol_id, host, port, max_age_seconds, + version); for (; c != value.end() && (*c == ' ' || *c == '\t' || *c == ','); ++c) { } } @@ -267,8 +265,8 @@ std::string SpdyAltSvcWireFormat::SerializeHeaderFieldValue( value.push_back(c); } absl::StrAppend(&value, ":", altsvc.port, "\""); - if (altsvc.max_age != 86400) { - absl::StrAppend(&value, "; ma=", altsvc.max_age); + if (altsvc.max_age_seconds != 86400) { + absl::StrAppend(&value, "; ma=", altsvc.max_age_seconds); } if (!altsvc.version.empty()) { if (is_ietf_format_quic) { @@ -315,12 +313,12 @@ bool SpdyAltSvcWireFormat::PercentDecode(absl::string_view::const_iterator c, return false; } // Network byte order is big-endian. - char decoded = SpdyHexDigitToInt(*c) << 4; + char decoded = HexDigitToInt(*c) << 4; ++c; if (c == end || !std::isxdigit(*c)) { return false; } - decoded += SpdyHexDigitToInt(*c); + decoded += HexDigitToInt(*c); output->push_back(decoded); } return true; @@ -389,4 +387,41 @@ bool SpdyAltSvcWireFormat::ParsePositiveInteger32( return ParsePositiveIntegerImpl(c, end, value); } +// static +char SpdyAltSvcWireFormat::HexDigitToInt(char c) { + QUICHE_DCHECK(std::isxdigit(c)); + + if (std::isdigit(c)) { + return c - '0'; + } + if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } + + return 0; +} + +// static +bool SpdyAltSvcWireFormat::HexDecodeToUInt32(absl::string_view data, + uint32_t* value) { + if (data.empty() || data.length() > 8u) { + return false; + } + + *value = 0; + for (char c : data) { + if (!std::isxdigit(c)) { + return false; + } + + *value <<= 4; + *value += HexDigitToInt(c); + } + + return true; +} + } // namespace spdy diff --git a/gquiche/spdy/core/spdy_alt_svc_wire_format.h b/gquiche/spdy/core/spdy_alt_svc_wire_format.h index aca0c86d..39ef8bcb 100644 --- a/gquiche/spdy/core/spdy_alt_svc_wire_format.h +++ b/gquiche/spdy/core/spdy_alt_svc_wire_format.h @@ -35,15 +35,13 @@ class QUICHE_EXPORT_PRIVATE SpdyAltSvcWireFormat { // Default is 0: invalid port. uint16_t port = 0; // Default is one day. - uint32_t max_age = 86400; + uint32_t max_age_seconds = 86400; // Default is empty: unspecified version. VersionVector version; AlternativeService(); - AlternativeService(const std::string& protocol_id, - const std::string& host, - uint16_t port, - uint32_t max_age, + AlternativeService(const std::string& protocol_id, const std::string& host, + uint16_t port, uint32_t max_age_seconds, VersionVector version); AlternativeService(const AlternativeService& other); ~AlternativeService(); @@ -51,7 +49,7 @@ class QUICHE_EXPORT_PRIVATE SpdyAltSvcWireFormat { bool operator==(const AlternativeService& other) const { return protocol_id == other.protocol_id && host == other.host && port == other.port && version == other.version && - max_age == other.max_age; + max_age_seconds == other.max_age_seconds; } }; // An empty vector means alternative services should be cleared for given @@ -66,21 +64,40 @@ class QUICHE_EXPORT_PRIVATE SpdyAltSvcWireFormat { const AlternativeServiceVector& altsvc_vector); private: + // Forward |*c| over space and tab or until |end| is reached. static void SkipWhiteSpace(absl::string_view::const_iterator* c, absl::string_view::const_iterator end); + // Decode percent-decoded string between |c| and |end| into |*output|. + // Return true on success, false if input is invalid. static bool PercentDecode(absl::string_view::const_iterator c, absl::string_view::const_iterator end, std::string* output); + // Parse the authority part of Alt-Svc between |c| and |end| into |*host| and + // |*port|. Return true on success, false if input is invalid. static bool ParseAltAuthority(absl::string_view::const_iterator c, absl::string_view::const_iterator end, std::string* host, uint16_t* port); + // Parse a positive integer between |c| and |end| into |*value|. + // Return true on success, false if input is not a positive integer or it + // cannot be represented on uint16_t. static bool ParsePositiveInteger16(absl::string_view::const_iterator c, absl::string_view::const_iterator end, uint16_t* value); + // Parse a positive integer between |c| and |end| into |*value|. + // Return true on success, false if input is not a positive integer or it + // cannot be represented on uint32_t. static bool ParsePositiveInteger32(absl::string_view::const_iterator c, absl::string_view::const_iterator end, uint32_t* value); + // Parse |c| as hexadecimal digit, case insensitive. |c| must be [0-9a-fA-F]. + // Output is between 0 and 15. + static char HexDigitToInt(char c); + // Parse |data| as hexadecimal number into |*value|. |data| must only contain + // hexadecimal digits, no "0x" prefix. + // Return true on success, false if input is empty, not valid hexadecimal + // number, or cannot be represented on uint32_t. + static bool HexDecodeToUInt32(absl::string_view data, uint32_t* value); }; } // namespace spdy diff --git a/gquiche/spdy/core/spdy_alt_svc_wire_format_test.cc b/gquiche/spdy/core/spdy_alt_svc_wire_format_test.cc index 0e6845b8..b0db8d8a 100644 --- a/gquiche/spdy/core/spdy_alt_svc_wire_format_test.cc +++ b/gquiche/spdy/core/spdy_alt_svc_wire_format_test.cc @@ -30,18 +30,24 @@ class SpdyAltSvcWireFormatPeer { } static bool ParsePositiveInteger16(absl::string_view::const_iterator c, absl::string_view::const_iterator end, - uint16_t* max_age) { - return SpdyAltSvcWireFormat::ParsePositiveInteger16(c, end, max_age); + uint16_t* max_age_seconds) { + return SpdyAltSvcWireFormat::ParsePositiveInteger16(c, end, + max_age_seconds); } static bool ParsePositiveInteger32(absl::string_view::const_iterator c, absl::string_view::const_iterator end, - uint32_t* max_age) { - return SpdyAltSvcWireFormat::ParsePositiveInteger32(c, end, max_age); + uint32_t* max_age_seconds) { + return SpdyAltSvcWireFormat::ParsePositiveInteger32(c, end, + max_age_seconds); + } + static char HexDigitToInt(char c) { + return SpdyAltSvcWireFormat::HexDigitToInt(c); + } + static bool HexDecodeToUInt32(absl::string_view data, uint32_t* value) { + return SpdyAltSvcWireFormat::HexDecodeToUInt32(data, value); } }; -} // namespace test - namespace { // Generate header field values, possibly with multiply defined parameters and @@ -75,7 +81,7 @@ void FuzzHeaderFieldValue( header_field_value->append(" "); } if (i & 3 << 3) { - expected_altsvc->max_age = 1111; + expected_altsvc->max_age_seconds = 1111; header_field_value->append(";"); if (i & 1 << 3) { header_field_value->append(" "); @@ -110,7 +116,7 @@ void FuzzHeaderFieldValue( } } if (i & 1 << 8) { - expected_altsvc->max_age = 999999999; + expected_altsvc->max_age_seconds = 999999999; header_field_value->append("; Ma=999999999"); } if (i & 1 << 9) { @@ -144,7 +150,7 @@ void FuzzAlternativeService(int i, } expected_header_field_value->append(":42\""); if (i & 1 << 1) { - altsvc->max_age = 1111; + altsvc->max_age_seconds = 1111; expected_header_field_value->append("; ma=1111"); } if (i & 1 << 2) { @@ -161,7 +167,7 @@ TEST(SpdyAltSvcWireFormatTest, DefaultValues) { EXPECT_EQ("", altsvc.protocol_id); EXPECT_EQ("", altsvc.host); EXPECT_EQ(0u, altsvc.port); - EXPECT_EQ(86400u, altsvc.max_age); + EXPECT_EQ(86400u, altsvc.max_age_seconds); EXPECT_TRUE(altsvc.version.empty()); } @@ -192,7 +198,8 @@ TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValue) { EXPECT_EQ(expected_altsvc.protocol_id, altsvc_vector[0].protocol_id); EXPECT_EQ(expected_altsvc.host, altsvc_vector[0].host); EXPECT_EQ(expected_altsvc.port, altsvc_vector[0].port); - EXPECT_EQ(expected_altsvc.max_age, altsvc_vector[0].max_age); + EXPECT_EQ(expected_altsvc.max_age_seconds, + altsvc_vector[0].max_age_seconds); EXPECT_EQ(expected_altsvc.version, altsvc_vector[0].version); // Roundtrip test starting with |altsvc_vector|. @@ -206,7 +213,8 @@ TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValue) { roundtrip_altsvc_vector[0].protocol_id); EXPECT_EQ(expected_altsvc.host, roundtrip_altsvc_vector[0].host); EXPECT_EQ(expected_altsvc.port, roundtrip_altsvc_vector[0].port); - EXPECT_EQ(expected_altsvc.max_age, roundtrip_altsvc_vector[0].max_age); + EXPECT_EQ(expected_altsvc.max_age_seconds, + roundtrip_altsvc_vector[0].max_age_seconds); EXPECT_EQ(expected_altsvc.version, roundtrip_altsvc_vector[0].version); } } @@ -236,7 +244,8 @@ TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValueMultiple) { altsvc_vector[j].protocol_id); EXPECT_EQ(expected_altsvc_vector[j].host, altsvc_vector[j].host); EXPECT_EQ(expected_altsvc_vector[j].port, altsvc_vector[j].port); - EXPECT_EQ(expected_altsvc_vector[j].max_age, altsvc_vector[j].max_age); + EXPECT_EQ(expected_altsvc_vector[j].max_age_seconds, + altsvc_vector[j].max_age_seconds); EXPECT_EQ(expected_altsvc_vector[j].version, altsvc_vector[j].version); } @@ -254,8 +263,8 @@ TEST(SpdyAltSvcWireFormatTest, ParseHeaderFieldValueMultiple) { roundtrip_altsvc_vector[j].host); EXPECT_EQ(expected_altsvc_vector[j].port, roundtrip_altsvc_vector[j].port); - EXPECT_EQ(expected_altsvc_vector[j].max_age, - roundtrip_altsvc_vector[j].max_age); + EXPECT_EQ(expected_altsvc_vector[j].max_age_seconds, + roundtrip_altsvc_vector[j].max_age_seconds); EXPECT_EQ(expected_altsvc_vector[j].version, roundtrip_altsvc_vector[j].version); } @@ -286,7 +295,7 @@ TEST(SpdyAltSvcWireFormatTest, RoundTrip) { EXPECT_EQ(altsvc.protocol_id, parsed_altsvc_vector[0].protocol_id); EXPECT_EQ(altsvc.host, parsed_altsvc_vector[0].host); EXPECT_EQ(altsvc.port, parsed_altsvc_vector[0].port); - EXPECT_EQ(altsvc.max_age, parsed_altsvc_vector[0].max_age); + EXPECT_EQ(altsvc.max_age_seconds, parsed_altsvc_vector[0].max_age_seconds); EXPECT_EQ(altsvc.version, parsed_altsvc_vector[0].version); // Test SerializeHeaderFieldValue(). @@ -321,7 +330,7 @@ TEST(SpdyAltSvcWireFormatTest, RoundTripMultiple) { EXPECT_EQ(expected_it->protocol_id, parsed_it->protocol_id); EXPECT_EQ(expected_it->host, parsed_it->host); EXPECT_EQ(expected_it->port, parsed_it->port); - EXPECT_EQ(expected_it->max_age, parsed_it->max_age); + EXPECT_EQ(expected_it->max_age_seconds, parsed_it->max_age_seconds); EXPECT_EQ(expected_it->version, parsed_it->version); } @@ -388,13 +397,13 @@ TEST(SpdyAltSvcWireFormatTest, ParseTruncatedHeaderFieldValue) { TEST(SpdyAltSvcWireFormatTest, SkipWhiteSpace) { absl::string_view input("a \tb "); absl::string_view::const_iterator c = input.begin(); - test::SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); + SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); ASSERT_EQ(input.begin(), c); ++c; - test::SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); + SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); ASSERT_EQ(input.begin() + 3, c); ++c; - test::SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); + SpdyAltSvcWireFormatPeer::SkipWhiteSpace(&c, input.end()); ASSERT_EQ(input.end(), c); } @@ -402,20 +411,20 @@ TEST(SpdyAltSvcWireFormatTest, SkipWhiteSpace) { TEST(SpdyAltSvcWireFormatTest, PercentDecodeValid) { absl::string_view input(""); std::string output; - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::PercentDecode( - input.begin(), input.end(), &output)); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)); EXPECT_EQ("", output); input = absl::string_view("foo"); output.clear(); - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::PercentDecode( - input.begin(), input.end(), &output)); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)); EXPECT_EQ("foo", output); input = absl::string_view("%2ca%5Cb"); output.clear(); - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::PercentDecode( - input.begin(), input.end(), &output)); + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)); EXPECT_EQ(",a\\b", output); } @@ -425,8 +434,8 @@ TEST(SpdyAltSvcWireFormatTest, PercentDecodeInvalid) { for (const char* invalid_input : invalid_input_array) { absl::string_view input(invalid_input); std::string output; - EXPECT_FALSE(test::SpdyAltSvcWireFormatPeer::PercentDecode( - input.begin(), input.end(), &output)) + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::PercentDecode(input.begin(), + input.end(), &output)) << input; } } @@ -436,19 +445,19 @@ TEST(SpdyAltSvcWireFormatTest, ParseAltAuthorityValid) { absl::string_view input(":42"); std::string host; uint16_t port; - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParseAltAuthority( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( input.begin(), input.end(), &host, &port)); EXPECT_TRUE(host.empty()); EXPECT_EQ(42, port); input = absl::string_view("foo:137"); - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParseAltAuthority( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( input.begin(), input.end(), &host, &port)); EXPECT_EQ("foo", host); EXPECT_EQ(137, port); input = absl::string_view("[2003:8:0:16::509d:9615]:443"); - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParseAltAuthority( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( input.begin(), input.end(), &host, &port)); EXPECT_EQ("[2003:8:0:16::509d:9615]", host); EXPECT_EQ(443, port); @@ -477,7 +486,7 @@ TEST(SpdyAltSvcWireFormatTest, ParseAltAuthorityInvalid) { absl::string_view input(invalid_input); std::string host; uint16_t port; - EXPECT_FALSE(test::SpdyAltSvcWireFormatPeer::ParseAltAuthority( + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::ParseAltAuthority( input.begin(), input.end(), &host, &port)) << input; } @@ -487,12 +496,12 @@ TEST(SpdyAltSvcWireFormatTest, ParseAltAuthorityInvalid) { TEST(SpdyAltSvcWireFormatTest, ParseIntegerValid) { absl::string_view input("3"); uint16_t value; - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( input.begin(), input.end(), &value)); EXPECT_EQ(3, value); input = absl::string_view("1337"); - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( input.begin(), input.end(), &value)); EXPECT_EQ(1337, value); } @@ -504,7 +513,7 @@ TEST(SpdyAltSvcWireFormatTest, ParseIntegerInvalid) { for (const char* invalid_input : invalid_input_array) { absl::string_view input(invalid_input); uint16_t value; - EXPECT_FALSE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( input.begin(), input.end(), &value)) << input; } @@ -515,38 +524,38 @@ TEST(SpdyAltSvcWireFormatTest, ParseIntegerOverflow) { // Largest possible uint16_t value. absl::string_view input("65535"); uint16_t value16; - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( input.begin(), input.end(), &value16)); EXPECT_EQ(65535, value16); // Overflow uint16_t, ParsePositiveInteger16() should return false. input = absl::string_view("65536"); - ASSERT_FALSE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( input.begin(), input.end(), &value16)); // However, even if overflow is not checked for, 65536 overflows to 0, which // returns false anyway. Check for a larger number which overflows to 1. input = absl::string_view("65537"); - ASSERT_FALSE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger16( input.begin(), input.end(), &value16)); // Largest possible uint32_t value. input = absl::string_view("4294967295"); uint32_t value32; - ASSERT_TRUE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( + ASSERT_TRUE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( input.begin(), input.end(), &value32)); EXPECT_EQ(4294967295, value32); // Overflow uint32_t, ParsePositiveInteger32() should return false. input = absl::string_view("4294967296"); - ASSERT_FALSE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( input.begin(), input.end(), &value32)); // However, even if overflow is not checked for, 4294967296 overflows to 0, // which returns false anyway. Check for a larger number which overflows to // 1. input = absl::string_view("4294967297"); - ASSERT_FALSE(test::SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( + ASSERT_FALSE(SpdyAltSvcWireFormatPeer::ParsePositiveInteger32( input.begin(), input.end(), &value32)); } @@ -562,10 +571,68 @@ TEST(SpdyAltSvcWireFormatTest, ParseIPLiteral) { EXPECT_EQ("quic", altsvc_vector[0].protocol_id); EXPECT_EQ("[2003:8:0:16::509d:9615]", altsvc_vector[0].host); EXPECT_EQ(443u, altsvc_vector[0].port); - EXPECT_EQ(60u, altsvc_vector[0].max_age); + EXPECT_EQ(60u, altsvc_vector[0].max_age_seconds); EXPECT_THAT(altsvc_vector[0].version, ::testing::ElementsAre(36, 35)); } +TEST(SpdyAltSvcWireFormatTest, HexDigitToInt) { + EXPECT_EQ(0, SpdyAltSvcWireFormatPeer::HexDigitToInt('0')); + EXPECT_EQ(1, SpdyAltSvcWireFormatPeer::HexDigitToInt('1')); + EXPECT_EQ(2, SpdyAltSvcWireFormatPeer::HexDigitToInt('2')); + EXPECT_EQ(3, SpdyAltSvcWireFormatPeer::HexDigitToInt('3')); + EXPECT_EQ(4, SpdyAltSvcWireFormatPeer::HexDigitToInt('4')); + EXPECT_EQ(5, SpdyAltSvcWireFormatPeer::HexDigitToInt('5')); + EXPECT_EQ(6, SpdyAltSvcWireFormatPeer::HexDigitToInt('6')); + EXPECT_EQ(7, SpdyAltSvcWireFormatPeer::HexDigitToInt('7')); + EXPECT_EQ(8, SpdyAltSvcWireFormatPeer::HexDigitToInt('8')); + EXPECT_EQ(9, SpdyAltSvcWireFormatPeer::HexDigitToInt('9')); + + EXPECT_EQ(10, SpdyAltSvcWireFormatPeer::HexDigitToInt('a')); + EXPECT_EQ(11, SpdyAltSvcWireFormatPeer::HexDigitToInt('b')); + EXPECT_EQ(12, SpdyAltSvcWireFormatPeer::HexDigitToInt('c')); + EXPECT_EQ(13, SpdyAltSvcWireFormatPeer::HexDigitToInt('d')); + EXPECT_EQ(14, SpdyAltSvcWireFormatPeer::HexDigitToInt('e')); + EXPECT_EQ(15, SpdyAltSvcWireFormatPeer::HexDigitToInt('f')); + + EXPECT_EQ(10, SpdyAltSvcWireFormatPeer::HexDigitToInt('A')); + EXPECT_EQ(11, SpdyAltSvcWireFormatPeer::HexDigitToInt('B')); + EXPECT_EQ(12, SpdyAltSvcWireFormatPeer::HexDigitToInt('C')); + EXPECT_EQ(13, SpdyAltSvcWireFormatPeer::HexDigitToInt('D')); + EXPECT_EQ(14, SpdyAltSvcWireFormatPeer::HexDigitToInt('E')); + EXPECT_EQ(15, SpdyAltSvcWireFormatPeer::HexDigitToInt('F')); +} + +TEST(SpdyAltSvcWireFormatTest, HexDecodeToUInt32) { + uint32_t out; + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("0", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("00", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("0000000", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("00000000", &out)); + EXPECT_EQ(0u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("1", &out)); + EXPECT_EQ(1u, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("ffffFFF", &out)); + EXPECT_EQ(0xFFFFFFFu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("fFfFffFf", &out)); + EXPECT_EQ(0xFFFFFFFFu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("01AEF", &out)); + EXPECT_EQ(0x1AEFu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("abcde", &out)); + EXPECT_EQ(0xABCDEu, out); + EXPECT_TRUE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("1234abcd", &out)); + EXPECT_EQ(0x1234ABCDu, out); + + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("", &out)); + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("111111111", &out)); + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("1111111111", &out)); + EXPECT_FALSE(SpdyAltSvcWireFormatPeer::HexDecodeToUInt32("0x1111", &out)); +} + } // namespace +} // namespace test + } // namespace spdy diff --git a/gquiche/spdy/core/spdy_framer.cc b/gquiche/spdy/core/spdy_framer.cc index 58bdc76b..0cd8da2e 100644 --- a/gquiche/spdy/core/spdy_framer.cc +++ b/gquiche/spdy/core/spdy_framer.cc @@ -18,8 +18,6 @@ #include "gquiche/spdy/core/spdy_bitmasks.h" #include "gquiche/spdy/core/spdy_frame_builder.h" #include "gquiche/spdy/core/spdy_frame_reader.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" -#include "gquiche/spdy/platform/api/spdy_string_utils.h" namespace spdy { @@ -303,9 +301,8 @@ size_t SpdyFramer::SpdyFrameIterator::NextFrame(ZeroCopyOutputBuffer* output) { const size_t size_without_block = is_first_frame_ ? GetFrameSizeSansBlock() : kContinuationFrameMinimumSize; - auto encoding = std::make_unique(); - encoder_->Next(kHttp2MaxControlFrameSendSize - size_without_block, - encoding.get()); + std::string encoding = + encoder_->Next(kHttp2MaxControlFrameSendSize - size_without_block); has_next_frame_ = encoder_->HasNext(); if (framer_->debug_visitor_ != nullptr) { @@ -316,14 +313,14 @@ size_t SpdyFramer::SpdyFrameIterator::NextFrame(ZeroCopyOutputBuffer* output) { framer_->debug_visitor_->OnSendCompressedFrame( frame_ir.stream_id(), is_first_frame_ ? frame_ir.frame_type() : SpdyFrameType::CONTINUATION, - header_list_size, size_without_block + encoding->size()); + header_list_size, size_without_block + encoding.size()); } const size_t free_bytes_before = output->BytesFree(); bool ok = false; if (is_first_frame_) { is_first_frame_ = false; - ok = SerializeGivenEncoding(*encoding, output); + ok = SerializeGivenEncoding(encoding, output); } else { SpdyContinuationIR continuation_ir(frame_ir.stream_id()); continuation_ir.take_encoding(std::move(encoding)); @@ -347,7 +344,7 @@ SpdyFramer::SpdyHeaderFrameIterator::SpdyHeaderFrameIterator( SpdyFramer::SpdyHeaderFrameIterator::~SpdyHeaderFrameIterator() = default; const SpdyFrameIR& SpdyFramer::SpdyHeaderFrameIterator::GetIR() const { - return *(headers_ir_.get()); + return *headers_ir_; } size_t SpdyFramer::SpdyHeaderFrameIterator::GetFrameSizeSansBlock() const { @@ -372,7 +369,7 @@ SpdyFramer::SpdyPushPromiseFrameIterator::~SpdyPushPromiseFrameIterator() = default; const SpdyFrameIR& SpdyFramer::SpdyPushPromiseFrameIterator::GetIR() const { - return *(push_promise_ir_.get()); + return *push_promise_ir_; } size_t SpdyFramer::SpdyPushPromiseFrameIterator::GetFrameSizeSansBlock() const { @@ -405,7 +402,7 @@ bool SpdyFramer::SpdyControlFrameIterator::HasNextFrame() const { } const SpdyFrameIR& SpdyFramer::SpdyControlFrameIterator::GetIR() const { - return *(frame_ir_.get()); + return *frame_ir_; } std::unique_ptr SpdyFramer::CreateIterator( @@ -580,7 +577,8 @@ void SpdyFramer::SerializeHeadersBuilderHelper(const SpdyHeadersIR& headers, *size = *size + 5; } - GetHpackEncoder()->EncodeHeaderSet(headers.header_block(), hpack_encoding); + *hpack_encoding = + GetHpackEncoder()->EncodeHeaderBlock(headers.header_block()); *size = *size + hpack_encoding->size(); if (*size > kHttp2MaxControlFrameSendSize) { *size = *size + GetNumberRequiredContinuationFrames(*size) * @@ -673,8 +671,8 @@ void SpdyFramer::SerializePushPromiseBuilderHelper( *size = *size + push_promise.padding_payload_len(); } - GetHpackEncoder()->EncodeHeaderSet(push_promise.header_block(), - hpack_encoding); + *hpack_encoding = + GetHpackEncoder()->EncodeHeaderBlock(push_promise.header_block()); *size = *size + hpack_encoding->size(); if (*size > kHttp2MaxControlFrameSendSize) { *size = *size + GetNumberRequiredContinuationFrames(*size) * @@ -1380,8 +1378,4 @@ size_t SpdyFramer::header_encoder_table_size() const { } } -size_t SpdyFramer::EstimateMemoryUsage() const { - return SpdyEstimateMemoryUsage(hpack_encoder_); -} - } // namespace spdy diff --git a/gquiche/spdy/core/spdy_framer.h b/gquiche/spdy/core/spdy_framer.h index d2837ae9..4bc888b0 100644 --- a/gquiche/spdy/core/spdy_framer.h +++ b/gquiche/spdy/core/spdy_framer.h @@ -236,8 +236,9 @@ class QUICHE_EXPORT_PRIVATE SpdyFramer { // Get (and lazily initialize) the HPACK encoder state. HpackEncoder* GetHpackEncoder(); - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; + // Gets the HPACK encoder state. Returns nullptr if the encoder has not been + // initialized. + const HpackEncoder* GetHpackEncoder() const { return hpack_encoder_.get(); } protected: friend class test::SpdyFramerPeer; diff --git a/gquiche/spdy/core/spdy_framer_test.cc b/gquiche/spdy/core/spdy_framer_test.cc index ade884c7..543cb0d9 100644 --- a/gquiche/spdy/core/spdy_framer_test.cc +++ b/gquiche/spdy/core/spdy_framer_test.cc @@ -16,6 +16,7 @@ #include "absl/base/macros.h" #include "gquiche/common/platform/api/quiche_logging.h" #include "gquiche/common/platform/api/quiche_test.h" +#include "gquiche/common/quiche_text_utils.h" #include "gquiche/spdy/core/array_output_buffer.h" #include "gquiche/spdy/core/mock_spdy_framer_visitor.h" #include "gquiche/spdy/core/recording_headers_handler.h" @@ -24,7 +25,6 @@ #include "gquiche/spdy/core/spdy_frame_reader.h" #include "gquiche/spdy/core/spdy_protocol.h" #include "gquiche/spdy/core/spdy_test_utils.h" -#include "gquiche/spdy/platform/api/spdy_string_utils.h" using ::http2::Http2DecoderAdapter; using ::testing::_; @@ -287,7 +287,8 @@ class TestSpdyVisitor : public SpdyFramerVisitorInterface, QUICHE_VLOG(1) << "OnStreamFrameData(" << stream_id << ", data, " << len << ", " << ") data:\n" - << SpdyHexDump(absl::string_view(data, len)); + << quiche::QuicheTextUtils::HexDump( + absl::string_view(data, len)); EXPECT_EQ(header_stream_id_, stream_id); data_bytes_ += len; @@ -702,7 +703,7 @@ TEST_P(SpdyFramerTest, AcceptMaxFrameSizeSetting) { // DATA frame with maximum allowed payload length. unsigned char kH2FrameData[] = { 0x00, 0x40, 0x00, // Length: 2^14 - 0x00, // Type: HEADERS + 0x00, // Type: DATA 0x00, // Flags: None 0x00, 0x00, 0x00, 0x01, // Stream: 1 0x00, 0x00, 0x00, 0x00, // Junk payload @@ -711,6 +712,7 @@ TEST_P(SpdyFramerTest, AcceptMaxFrameSizeSetting) { SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), sizeof(kH2FrameData), false); + EXPECT_CALL(visitor, OnCommonHeader(1, 16384, 0x0, 0x0)); EXPECT_CALL(visitor, OnDataFrameHeader(1, 1 << 14, false)); EXPECT_CALL(visitor, OnStreamFrameData(1, _, 4)); deframer_.ProcessInput(frame.data(), frame.size()); @@ -726,7 +728,7 @@ TEST_P(SpdyFramerTest, ExceedMaxFrameSizeSetting) { // DATA frame with too large payload length. unsigned char kH2FrameData[] = { 0x00, 0x40, 0x01, // Length: 2^14 + 1 - 0x00, // Type: HEADERS + 0x00, // Type: DATA 0x00, // Flags: None 0x00, 0x00, 0x00, 0x01, // Stream: 1 0x00, 0x00, 0x00, 0x00, // Junk payload @@ -735,6 +737,7 @@ TEST_P(SpdyFramerTest, ExceedMaxFrameSizeSetting) { SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), sizeof(kH2FrameData), false); + EXPECT_CALL(visitor, OnCommonHeader(1, 16385, 0x0, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_OVERSIZED_PAYLOAD, _)); deframer_.ProcessInput(frame.data(), frame.size()); EXPECT_TRUE(deframer_.HasError()); @@ -768,6 +771,7 @@ TEST_P(SpdyFramerTest, OversizedDataPaddingError) { { testing::InSequence seq; + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x0, 0x9)); EXPECT_CALL(visitor, OnDataFrameHeader(1, 5, 1)); EXPECT_CALL(visitor, OnStreamPadding(1, 1)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_PADDING, _)); @@ -801,6 +805,7 @@ TEST_P(SpdyFramerTest, CorrectlySizedDataPaddingNoError) { { testing::InSequence seq; + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x0, 0x8)); EXPECT_CALL(visitor, OnDataFrameHeader(1, 5, false)); EXPECT_CALL(visitor, OnStreamPadLength(1, 4)); EXPECT_CALL(visitor, OnError(_, _)).Times(0); @@ -839,6 +844,7 @@ TEST_P(SpdyFramerTest, OversizedHeadersPaddingError) { SpdySerializedFrame frame(reinterpret_cast(kH2FrameData), sizeof(kH2FrameData), false); + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x1, 0x8)); EXPECT_CALL(visitor, OnHeaders(1, false, 0, 0, false, false, false)); EXPECT_CALL(visitor, OnHeaderFrameStart(1)).Times(1); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_PADDING, _)); @@ -869,6 +875,7 @@ TEST_P(SpdyFramerTest, CorrectlySizedHeadersPaddingNoError) { SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x1, 0x8)); EXPECT_CALL(visitor, OnHeaders(1, false, 0, 0, false, false, false)); EXPECT_CALL(visitor, OnHeaderFrameStart(1)).Times(1); @@ -891,6 +898,7 @@ TEST_P(SpdyFramerTest, DataWithStreamIdZero) { SpdySerializedFrame frame(framer_.SerializeData(data_ir)); // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x0, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -913,6 +921,7 @@ TEST_P(SpdyFramerTest, HeadersWithStreamIdZero) { SpdyFramerPeer::SerializeHeaders(&framer_, headers, &output_)); // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x1, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -940,6 +949,7 @@ TEST_P(SpdyFramerTest, PriorityWithStreamIdZero) { } // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x2, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -964,6 +974,7 @@ TEST_P(SpdyFramerTest, RstStreamWithStreamIdZero) { } // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x3, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -993,6 +1004,7 @@ TEST_P(SpdyFramerTest, SettingsWithStreamIdNotZero) { SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(1, 6, 0x4, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -1023,6 +1035,7 @@ TEST_P(SpdyFramerTest, GoawayWithStreamIdNotZero) { SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(1, 10, 0x7, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -1040,8 +1053,7 @@ TEST_P(SpdyFramerTest, ContinuationWithStreamIdZero) { deframer_.set_visitor(&visitor); SpdyContinuationIR continuation(/* stream_id = */ 0); - auto some_nonsense_encoding = - std::make_unique("some nonsense encoding"); + std::string some_nonsense_encoding = "some nonsense encoding"; continuation.take_encoding(std::move(some_nonsense_encoding)); continuation.set_end_headers(true); SpdySerializedFrame frame(framer_.SerializeContinuation(continuation)); @@ -1051,6 +1063,7 @@ TEST_P(SpdyFramerTest, ContinuationWithStreamIdZero) { } // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x9, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -1074,6 +1087,7 @@ TEST_P(SpdyFramerTest, PushPromiseWithStreamIdZero) { &framer_, push_promise, use_output_ ? &output_ : nullptr)); // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x5, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -1096,6 +1110,7 @@ TEST_P(SpdyFramerTest, PushPromiseWithPromisedStreamIdZero) { SpdySerializedFrame frame(SpdyFramerPeer::SerializePushPromise( &framer_, push_promise, use_output_ ? &output_ : nullptr)); + EXPECT_CALL(visitor, OnCommonHeader(3, _, 0x5, _)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, _)); deframer_.ProcessInput(frame.data(), frame.size()); @@ -1112,10 +1127,9 @@ TEST_P(SpdyFramerTest, MultiValueHeader) { // TODO(jgraettinger): If this pattern appears again, move to test class. Http2HeaderBlock header_set; header_set["name"] = value; - std::string buffer; HpackEncoder encoder; encoder.DisableCompression(); - encoder.EncodeHeaderSet(header_set, &buffer); + std::string buffer = encoder.EncodeHeaderBlock(header_set); // Frame builder with plentiful buffer size. SpdyFrameBuilder frame(1024); frame.BeginNewFrame(SpdyFrameType::HEADERS, @@ -1230,6 +1244,96 @@ TEST_P(SpdyFramerTest, Basic) { EXPECT_EQ(4, visitor.data_frame_count_); } +// Verifies that the decoder stops delivering events after a user error. +TEST_P(SpdyFramerTest, BasicWithError) { + // Send HEADERS frames with PRIORITY and END_HEADERS set. + // frame-format off + const unsigned char kH2Input[] = { + 0x00, 0x00, 0x01, // Length: 1 + 0x01, // Type: HEADERS + 0x04, // Flags: END_HEADERS + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x8c, // :status: 200 + + 0x00, 0x00, 0x0c, // Length: 12 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x06, // Length: 6 + 0x01, // Type: HEADERS + 0x24, // Flags: END_HEADERS|PRIORITY + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x00, // Parent: 0 + 0x82, // Weight: 131 + 0x8c, // :status: 200 + + 0x00, 0x00, 0x08, // Length: 8 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0xde, 0xad, 0xbe, 0xef, // Payload + 0xde, 0xad, 0xbe, 0xef, // + + 0x00, 0x00, 0x04, // Length: 4 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0xde, 0xad, 0xbe, 0xef, // Payload + + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x01, // Stream: 1 + 0x00, 0x00, 0x00, 0x08, // Error: CANCEL + + 0x00, 0x00, 0x00, // Length: 0 + 0x00, // Type: DATA + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + + 0x00, 0x00, 0x04, // Length: 4 + 0x03, // Type: RST_STREAM + 0x00, // Flags: none + 0x00, 0x00, 0x00, 0x03, // Stream: 3 + 0x00, 0x00, 0x00, 0x08, // Error: CANCEL + }; + // frame-format on + + testing::StrictMock visitor; + + deframer_.set_visitor(&visitor); + + testing::InSequence s; + EXPECT_CALL(visitor, OnCommonHeader(1, 1, 0x1, 0x4)); + EXPECT_CALL(visitor, OnHeaders(1, false, 0, 0, false, false, true)); + EXPECT_CALL(visitor, OnHeaderFrameStart(1)); + EXPECT_CALL(visitor, OnHeaderFrameEnd(1)); + EXPECT_CALL(visitor, OnCommonHeader(1, 12, 0x0, 0x0)); + EXPECT_CALL(visitor, OnDataFrameHeader(1, 12, false)); + EXPECT_CALL(visitor, OnStreamFrameData(1, _, 12)); + EXPECT_CALL(visitor, OnCommonHeader(3, 6, 0x1, 0x24)); + EXPECT_CALL(visitor, OnHeaders(3, true, 131, 0, false, false, true)); + EXPECT_CALL(visitor, OnHeaderFrameStart(3)); + EXPECT_CALL(visitor, OnHeaderFrameEnd(3)); + EXPECT_CALL(visitor, OnCommonHeader(3, 8, 0x0, 0x0)); + EXPECT_CALL(visitor, OnDataFrameHeader(3, 8, false)) + .WillOnce( + testing::InvokeWithoutArgs([this]() { deframer_.StopProcessing(); })); + // Remaining frames are not processed due to the error. + EXPECT_CALL( + visitor, + OnError(http2::Http2DecoderAdapter::SpdyFramerError::SPDY_STOP_PROCESSING, + "Ignoring further events on this connection.")); + + size_t processed = deframer_.ProcessInput( + reinterpret_cast(kH2Input), sizeof(kH2Input)); + EXPECT_LT(processed, sizeof(kH2Input)); +} + // Test that the FIN flag on a data frame signifies EOF. TEST_P(SpdyFramerTest, FinOnDataFrame) { // Send HEADERS frames with END_HEADERS set. @@ -2315,10 +2419,9 @@ TEST_P(SpdyFramerTest, CreateContinuationUncompressed) { Http2HeaderBlock header_block; header_block["bar"] = "foo"; header_block["foo"] = "bar"; - auto buffer = std::make_unique(); HpackEncoder encoder; encoder.DisableCompression(); - encoder.EncodeHeaderSet(header_block, buffer.get()); + std::string buffer = encoder.EncodeHeaderBlock(header_block); SpdyContinuationIR continuation(/* stream_id = */ 42); continuation.take_encoding(std::move(buffer)); @@ -2363,6 +2466,7 @@ TEST_P(SpdyFramerTest, SendUnexpectedContinuation) { SpdySerializedFrame frame(kH2FrameData, sizeof(kH2FrameData), false); // We shouldn't have to read the whole frame before we signal an error. + EXPECT_CALL(visitor, OnCommonHeader(42, 18, 0x9, 0x4)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME, _)); EXPECT_GT(frame.size(), deframer_.ProcessInput(frame.data(), frame.size())); EXPECT_TRUE(deframer_.HasError()); @@ -3187,6 +3291,8 @@ TEST_P(SpdyFramerTest, ProcessDataFrameWithPadding) { int bytes_consumed = 0; // Send the frame header. + EXPECT_CALL(visitor, + OnCommonHeader(1, kPaddingLen + strlen(data_payload), 0x0, 0x8)); EXPECT_CALL(visitor, OnDataFrameHeader(1, kPaddingLen + strlen(data_payload), false)); QUICHE_CHECK_EQ(kDataFrameMinimumSize, @@ -3836,33 +3942,15 @@ TEST_P(SpdyFramerTest, SpdyFramerErrorToStringTest) { EXPECT_STREQ("CONTROL_PAYLOAD_TOO_LARGE", Http2DecoderAdapter::SpdyFramerErrorToString( Http2DecoderAdapter::SPDY_CONTROL_PAYLOAD_TOO_LARGE)); - EXPECT_STREQ("ZLIB_INIT_FAILURE", - Http2DecoderAdapter::SpdyFramerErrorToString( - Http2DecoderAdapter::SPDY_ZLIB_INIT_FAILURE)); - EXPECT_STREQ("UNSUPPORTED_VERSION", - Http2DecoderAdapter::SpdyFramerErrorToString( - Http2DecoderAdapter::SPDY_UNSUPPORTED_VERSION)); EXPECT_STREQ("DECOMPRESS_FAILURE", Http2DecoderAdapter::SpdyFramerErrorToString( Http2DecoderAdapter::SPDY_DECOMPRESS_FAILURE)); - EXPECT_STREQ("COMPRESS_FAILURE", - Http2DecoderAdapter::SpdyFramerErrorToString( - Http2DecoderAdapter::SPDY_COMPRESS_FAILURE)); - EXPECT_STREQ("GOAWAY_FRAME_CORRUPT", - Http2DecoderAdapter::SpdyFramerErrorToString( - Http2DecoderAdapter::SPDY_GOAWAY_FRAME_CORRUPT)); - EXPECT_STREQ("RST_STREAM_FRAME_CORRUPT", - Http2DecoderAdapter::SpdyFramerErrorToString( - Http2DecoderAdapter::SPDY_RST_STREAM_FRAME_CORRUPT)); EXPECT_STREQ("INVALID_PADDING", Http2DecoderAdapter::SpdyFramerErrorToString( Http2DecoderAdapter::SPDY_INVALID_PADDING)); EXPECT_STREQ("INVALID_DATA_FRAME_FLAGS", Http2DecoderAdapter::SpdyFramerErrorToString( Http2DecoderAdapter::SPDY_INVALID_DATA_FRAME_FLAGS)); - EXPECT_STREQ("INVALID_CONTROL_FRAME_FLAGS", - Http2DecoderAdapter::SpdyFramerErrorToString( - Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_FLAGS)); EXPECT_STREQ("UNEXPECTED_FRAME", Http2DecoderAdapter::SpdyFramerErrorToString( Http2DecoderAdapter::SPDY_UNEXPECTED_FRAME)); @@ -3899,6 +3987,7 @@ TEST_P(SpdyFramerTest, DataFrameFlagsV4) { SpdySerializedFrame frame(framer_.SerializeData(data_ir)); SetFrameFlags(&frame, flags); + EXPECT_CALL(visitor, OnCommonHeader(1, 5, 0x0, flags)); if (flags & ~valid_data_flags) { EXPECT_CALL(visitor, OnError(_, _)); } else { @@ -3959,6 +4048,7 @@ TEST_P(SpdyFramerTest, RstStreamFrameFlags) { } SetFrameFlags(&frame, flags); + EXPECT_CALL(visitor, OnCommonHeader(13, 4, 0x3, flags)); EXPECT_CALL(visitor, OnRstStream(13, ERROR_CODE_CANCEL)); deframer_.ProcessInput(frame.data(), frame.size()); @@ -3989,6 +4079,7 @@ TEST_P(SpdyFramerTest, SettingsFrameFlags) { } SetFrameFlags(&frame, flags); + EXPECT_CALL(visitor, OnCommonHeader(0, 6, 0x4, flags)); if (flags & SETTINGS_FLAG_ACK) { EXPECT_CALL(visitor, OnError(_, _)); } else { @@ -4036,7 +4127,10 @@ TEST_P(SpdyFramerTest, GoawayFrameFlags) { } SetFrameFlags(&frame, flags); + EXPECT_CALL(visitor, OnCommonHeader(0, _, 0x7, flags)); EXPECT_CALL(visitor, OnGoAway(97, ERROR_CODE_NO_ERROR)); + EXPECT_CALL(visitor, OnGoAwayFrameData) + .WillRepeatedly(testing::Return(true)); deframer_.ProcessInput(frame.data(), frame.size()); EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_.state()); @@ -4085,6 +4179,7 @@ TEST_P(SpdyFramerTest, HeadersFrameFlags) { parent_stream_id = 5; exclusive = true; } + EXPECT_CALL(visitor, OnCommonHeader(stream_id, _, 0x1, set_flags)); EXPECT_CALL(visitor, OnHeaders(stream_id, has_priority, weight, parent_stream_id, exclusive, fin, end)); EXPECT_CALL(visitor, OnHeaderFrameStart(57)).Times(1); @@ -4119,6 +4214,7 @@ TEST_P(SpdyFramerTest, PingFrameFlags) { SpdySerializedFrame frame(framer_.SerializePing(SpdyPingIR(42))); SetFrameFlags(&frame, flags); + EXPECT_CALL(visitor, OnCommonHeader(0, 8, 0x6, flags)); EXPECT_CALL(visitor, OnPing(42, flags & PING_FLAG_ACK)); deframer_.ProcessInput(frame.data(), frame.size()); @@ -4144,6 +4240,7 @@ TEST_P(SpdyFramerTest, WindowUpdateFrameFlags) { SpdyWindowUpdateIR(/* stream_id = */ 4, /* delta = */ 1024))); SetFrameFlags(&frame, flags); + EXPECT_CALL(visitor, OnCommonHeader(4, 4, 0x8, flags)); EXPECT_CALL(visitor, OnWindowUpdate(4, 1024)); deframer_.ProcessInput(frame.data(), frame.size()); @@ -4186,6 +4283,8 @@ TEST_P(SpdyFramerTest, PushPromiseFrameFlags) { bool end = flags & PUSH_PROMISE_FLAG_END_PUSH_PROMISE; EXPECT_CALL(debug_visitor, OnReceiveCompressedFrame( client_id, SpdyFrameType::PUSH_PROMISE, _)); + EXPECT_CALL(visitor, OnCommonHeader(client_id, _, 0x5, + flags & ~HEADERS_FLAG_PADDED)); EXPECT_CALL(visitor, OnPushPromise(client_id, promised_id, end)); EXPECT_CALL(visitor, OnHeaderFrameStart(client_id)).Times(1); if (end) { @@ -4221,6 +4320,7 @@ TEST_P(SpdyFramerTest, ContinuationFrameFlags) { OnSendCompressedFrame(42, SpdyFrameType::HEADERS, _, _)); EXPECT_CALL(debug_visitor, OnReceiveCompressedFrame(42, SpdyFrameType::HEADERS, _)); + EXPECT_CALL(visitor, OnCommonHeader(42, _, 0x1, 0)); EXPECT_CALL(visitor, OnHeaders(42, false, 0, 0, false, false, false)); EXPECT_CALL(visitor, OnHeaderFrameStart(42)).Times(1); @@ -4249,6 +4349,7 @@ TEST_P(SpdyFramerTest, ContinuationFrameFlags) { EXPECT_CALL(debug_visitor, OnReceiveCompressedFrame(42, SpdyFrameType::CONTINUATION, _)); + EXPECT_CALL(visitor, OnCommonHeader(42, _, 0x9, flags)); EXPECT_CALL(visitor, OnContinuation(42, flags & HEADERS_FLAG_END_HEADERS)); bool end = flags & HEADERS_FLAG_END_HEADERS; if (end) { @@ -4286,6 +4387,7 @@ TEST_P(SpdyFramerTest, RstStreamStatusBounds) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(1, 4, 0x3, 0x0)); EXPECT_CALL(visitor, OnRstStream(1, ERROR_CODE_NO_ERROR)); deframer_.ProcessInput(reinterpret_cast(kH2RstStreamInvalid), ABSL_ARRAYSIZE(kH2RstStreamInvalid)); @@ -4295,6 +4397,7 @@ TEST_P(SpdyFramerTest, RstStreamStatusBounds) { deframer_.spdy_framer_error()); deframer_.Reset(); + EXPECT_CALL(visitor, OnCommonHeader(1, 4, 0x3, 0x0)); EXPECT_CALL(visitor, OnRstStream(1, ERROR_CODE_INTERNAL_ERROR)); deframer_.ProcessInput( reinterpret_cast(kH2RstStreamNumStatusCodes), @@ -4319,7 +4422,9 @@ TEST_P(SpdyFramerTest, GoAwayStatusBounds) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 10, 0x7, 0x0)); EXPECT_CALL(visitor, OnGoAway(1, ERROR_CODE_INTERNAL_ERROR)); + EXPECT_CALL(visitor, OnGoAwayFrameData).WillRepeatedly(testing::Return(true)); deframer_.ProcessInput(reinterpret_cast(kH2FrameData), ABSL_ARRAYSIZE(kH2FrameData)); EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_.state()); @@ -4343,7 +4448,9 @@ TEST_P(SpdyFramerTest, GoAwayStreamIdBounds) { deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 8, 0x7, 0x0)); EXPECT_CALL(visitor, OnGoAway(0x7fffffff, ERROR_CODE_NO_ERROR)); + EXPECT_CALL(visitor, OnGoAwayFrameData).WillRepeatedly(testing::Return(true)); deframer_.ProcessInput(reinterpret_cast(kH2FrameData), ABSL_ARRAYSIZE(kH2FrameData)); EXPECT_EQ(Http2DecoderAdapter::SPDY_READY_FOR_FRAME, deframer_.state()); @@ -4366,6 +4473,7 @@ TEST_P(SpdyFramerTest, OnAltSvcWithOrigin) { SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; altsvc_vector.push_back(altsvc1); altsvc_vector.push_back(altsvc2); + EXPECT_CALL(visitor, OnCommonHeader(kStreamId, _, 0x0A, 0x0)); EXPECT_CALL(visitor, OnAltSvc(kStreamId, absl::string_view("o_r|g!n"), altsvc_vector)); @@ -4401,6 +4509,7 @@ TEST_P(SpdyFramerTest, OnAltSvcNoOrigin) { SpdyAltSvcWireFormat::AlternativeServiceVector altsvc_vector; altsvc_vector.push_back(altsvc1); altsvc_vector.push_back(altsvc2); + EXPECT_CALL(visitor, OnCommonHeader(kStreamId, _, 0x0A, 0x0)); EXPECT_CALL(visitor, OnAltSvc(kStreamId, absl::string_view(""), altsvc_vector)); @@ -4423,6 +4532,7 @@ TEST_P(SpdyFramerTest, OnAltSvcEmptyProtocolId) { deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(kStreamId, _, 0x0A, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME, _)); @@ -4571,6 +4681,7 @@ TEST_P(SpdyFramerTest, ReadPriorityUpdateFrame) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 7, 0x10, 0x0)); EXPECT_CALL(visitor, OnPriorityUpdate(3, "foo")); deframer_.ProcessInput(kFrameData, sizeof(kFrameData)); EXPECT_FALSE(deframer_.HasError()); @@ -4588,6 +4699,7 @@ TEST_P(SpdyFramerTest, ReadPriorityUpdateFrameWithEmptyPriorityFieldValue) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 4, 0x10, 0x0)); EXPECT_CALL(visitor, OnPriorityUpdate(3, "")); deframer_.ProcessInput(kFrameData, sizeof(kFrameData)); EXPECT_FALSE(deframer_.HasError()); @@ -4604,6 +4716,7 @@ TEST_P(SpdyFramerTest, PriorityUpdateFrameWithEmptyPayload) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 0, 0x10, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, _)); deframer_.ProcessInput(kFrameData, sizeof(kFrameData)); @@ -4623,6 +4736,7 @@ TEST_P(SpdyFramerTest, PriorityUpdateFrameWithShortPayload) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 2, 0x10, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_CONTROL_FRAME_SIZE, _)); deframer_.ProcessInput(kFrameData, sizeof(kFrameData)); @@ -4641,6 +4755,7 @@ TEST_P(SpdyFramerTest, PriorityUpdateFrameOnIncorrectStream) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(1, 4, 0x10, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); deframer_.ProcessInput(kFrameData, sizeof(kFrameData)); EXPECT_TRUE(deframer_.HasError()); @@ -4658,6 +4773,7 @@ TEST_P(SpdyFramerTest, PriorityUpdateFramePrioritizingIncorrectStream) { testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(0, 4, 0x10, 0x0)); EXPECT_CALL(visitor, OnError(Http2DecoderAdapter::SPDY_INVALID_STREAM_ID, _)); deframer_.ProcessInput(kFrameData, sizeof(kFrameData)); EXPECT_TRUE(deframer_.HasError()); @@ -4677,6 +4793,7 @@ TEST_P(SpdyFramerTest, ReadPriority) { } testing::StrictMock visitor; deframer_.set_visitor(&visitor); + EXPECT_CALL(visitor, OnCommonHeader(3, 5, 0x2, 0x0)); EXPECT_CALL(visitor, OnPriority(3, 1, 256, false)); deframer_.ProcessInput(frame.data(), frame.size()); diff --git a/gquiche/spdy/core/spdy_header_block.cc b/gquiche/spdy/core/spdy_header_block.cc index b6e96c7b..6a4824f3 100644 --- a/gquiche/spdy/core/spdy_header_block.cc +++ b/gquiche/spdy/core/spdy_header_block.cc @@ -11,7 +11,6 @@ #include "absl/strings/str_cat.h" #include "gquiche/common/platform/api/quiche_logging.h" -#include "gquiche/spdy/platform/api/spdy_estimate_memory_usage.h" namespace spdy { namespace { @@ -299,12 +298,6 @@ void Http2HeaderBlock::AppendValueOrAddHeader(const absl::string_view key, iter->second.Append(storage_.Write(value)); } -size_t Http2HeaderBlock::EstimateMemoryUsage() const { - // TODO(xunjieli): https://crbug.com/669108. Also include |map_| when EMU() - // supports linked_hash_map. - return SpdyEstimateMemoryUsage(storage_); -} - void Http2HeaderBlock::AppendHeader(const absl::string_view key, const absl::string_view value) { auto backed_key = WriteKey(key); diff --git a/gquiche/spdy/core/spdy_header_block.h b/gquiche/spdy/core/spdy_header_block.h index 49da94ac..febb9898 100644 --- a/gquiche/spdy/core/spdy_header_block.h +++ b/gquiche/spdy/core/spdy_header_block.h @@ -14,14 +14,11 @@ #include #include "absl/base/attributes.h" -#include "absl/hash/hash.h" -#include "absl/strings/ascii.h" -#include "absl/strings/match.h" -#include "absl/strings/string_view.h" #include "gquiche/common/platform/api/quiche_export.h" #include "gquiche/common/platform/api/quiche_logging.h" +#include "gquiche/common/quiche_linked_hash_map.h" +#include "gquiche/common/quiche_text_utils.h" #include "gquiche/spdy/core/spdy_header_storage.h" -#include "gquiche/spdy/platform/api/spdy_containers.h" namespace spdy { @@ -94,24 +91,9 @@ class QUICHE_EXPORT_PRIVATE Http2HeaderBlock { size_t separator_size_ = 0; }; - struct StringPieceCaseHash { - size_t operator()(absl::string_view data) const { - std::string lower = absl::AsciiStrToLower(data); - absl::Hash hasher; - return hasher(lower); - } - }; - - struct StringPieceCaseEqual { - bool operator()(absl::string_view piece1, absl::string_view piece2) const { - return absl::EqualsIgnoreCase(piece1, piece2); - } - }; - - typedef SpdyLinkedHashMap + typedef quiche::QuicheLinkedHashMap MapType; public: @@ -201,6 +183,7 @@ class QUICHE_EXPORT_PRIVATE Http2HeaderBlock { const_iterator find(absl::string_view key) const { return wrap_const_iterator(map_.find(key)); } + bool contains(absl::string_view key) const { return find(key) != end(); } void erase(absl::string_view key); // Clears both our MapType member and the memory used to hold headers. @@ -262,9 +245,6 @@ class QUICHE_EXPORT_PRIVATE Http2HeaderBlock { // Allows either lookup or mutation of the value associated with a key. ABSL_MUST_USE_RESULT ValueProxy operator[](const absl::string_view key); - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const; - size_t TotalBytesUsed() const { return key_size_ + value_size_; } private: diff --git a/gquiche/spdy/core/spdy_header_block_test.cc b/gquiche/spdy/core/spdy_header_block_test.cc index 30b6420f..9126a269 100644 --- a/gquiche/spdy/core/spdy_header_block_test.cc +++ b/gquiche/spdy/core/spdy_header_block_test.cc @@ -33,6 +33,7 @@ TEST(Http2HeaderBlockTest, EmptyBlock) { EXPECT_TRUE(block.empty()); EXPECT_EQ(0u, block.size()); EXPECT_EQ(block.end(), block.find("foo")); + EXPECT_FALSE(block.contains("foo")); EXPECT_TRUE(block.end() == block.begin()); // Should have no effect. @@ -83,6 +84,7 @@ TEST(Http2HeaderBlockTest, AddHeaders) { std::string qux("qux"); EXPECT_EQ("qux2", block[qux]); ASSERT_NE(block.end(), block.find("key")); + ASSERT_TRUE(block.contains("key")); EXPECT_EQ(Pair("key", "value"), *block.find("key")); block.erase("key"); diff --git a/gquiche/spdy/core/spdy_header_storage.h b/gquiche/spdy/core/spdy_header_storage.h index 0fa63f6f..97db94c5 100644 --- a/gquiche/spdy/core/spdy_header_storage.h +++ b/gquiche/spdy/core/spdy_header_storage.h @@ -43,8 +43,6 @@ class QUICHE_EXPORT_PRIVATE SpdyHeaderStorage { size_t bytes_allocated() const { return arena_.status().bytes_allocated(); } - size_t EstimateMemoryUsage() const { return bytes_allocated(); } - private: SpdySimpleArena arena_; }; diff --git a/gquiche/spdy/core/spdy_intrusive_list.h b/gquiche/spdy/core/spdy_intrusive_list.h index 2d9f749f..3af20538 100644 --- a/gquiche/spdy/core/spdy_intrusive_list.h +++ b/gquiche/spdy/core/spdy_intrusive_list.h @@ -267,9 +267,10 @@ template class SpdyIntrusiveList { public: typedef std::iterator base; - iterator_impl() : link_(nullptr) {} + iterator_impl() = default; iterator_impl(QualifiedLinkT* link) : link_(link) {} - iterator_impl(const iterator_impl& x) : link_(x.link_) {} + iterator_impl(const iterator_impl& x) = default; + iterator_impl& operator=(const iterator_impl& x) = default; // Allow converting and comparing across iterators where the pointer // assignment and comparisons (respectively) are allowed. @@ -310,7 +311,7 @@ template class SpdyIntrusiveList { // Ensure iterators can access other iterators node directly. template friend class iterator_impl; - QualifiedLinkT* link_; + QualifiedLinkT* link_ = nullptr; }; // This bare link acts as the sentinel node. diff --git a/gquiche/spdy/core/spdy_protocol.cc b/gquiche/spdy/core/spdy_protocol.cc index cbd85c99..aa410654 100644 --- a/gquiche/spdy/core/spdy_protocol.cc +++ b/gquiche/spdy/core/spdy_protocol.cc @@ -446,7 +446,6 @@ size_t SpdyGoAwayIR::size() const { SpdyContinuationIR::SpdyContinuationIR(SpdyStreamId stream_id) : SpdyFrameIR(stream_id), end_headers_(false) { - encoding_ = std::make_unique(); } SpdyContinuationIR::~SpdyContinuationIR() = default; diff --git a/gquiche/spdy/core/spdy_protocol.h b/gquiche/spdy/core/spdy_protocol.h index ba4b94ac..fa8e2615 100644 --- a/gquiche/spdy/core/spdy_protocol.h +++ b/gquiche/spdy/core/spdy_protocol.h @@ -260,7 +260,7 @@ QUICHE_EXPORT_PRIVATE bool IsValidHTTP2FrameStreamId( SpdyFrameType frame_type_field); // Serialize |frame_type| to string for logging/debugging. -const char* FrameTypeToString(SpdyFrameType frame_type); +QUICHE_EXPORT_PRIVATE const char* FrameTypeToString(SpdyFrameType frame_type); // If |wire_setting_id| is the on-the-wire representation of a defined SETTINGS // parameter, parse it to |*setting_id| and return true. @@ -328,7 +328,7 @@ const int32_t kInitialStreamWindowSize = 64 * 1024 - 1; // Initial window size for a session in bytes. const int32_t kInitialSessionWindowSize = 64 * 1024 - 1; // The NPN string for HTTP2, "h2". -extern const char* const kHttp2Npn; +QUICHE_EXPORT_PRIVATE extern const char* const kHttp2Npn; // An estimate size of the HPACK overhead for each header field. 1 bytes for // indexed literal, 1 bytes for key literal and length encoding, and 2 bytes for // value literal and length encoding. @@ -351,7 +351,7 @@ QUICHE_EXPORT_PRIVATE size_t GetNumberRequiredContinuationFrames(size_t size); // exclusive bit}. Templated to allow for use by QUIC code; SPDY and HTTP/2 // code should use the concrete type instantiation SpdyStreamPrecedence. template -class StreamPrecedence { +class QUICHE_EXPORT_PRIVATE StreamPrecedence { public: // Constructs instance that is a SPDY 3.x priority. Clamps priority value to // the valid range [0, 7]. @@ -427,7 +427,7 @@ class StreamPrecedence { } private: - struct Http2StreamDependency { + struct QUICHE_EXPORT_PRIVATE Http2StreamDependency { StreamIdType parent_id; int weight; bool is_exclusive; @@ -827,14 +827,12 @@ class QUICHE_EXPORT_PRIVATE SpdyContinuationIR : public SpdyFrameIR { bool end_headers() const { return end_headers_; } void set_end_headers(bool end_headers) { end_headers_ = end_headers; } - const std::string& encoding() const { return *encoding_; } - void take_encoding(std::unique_ptr encoding) { - encoding_ = std::move(encoding); - } + const std::string& encoding() const { return encoding_; } + void take_encoding(std::string encoding) { encoding_ = std::move(encoding); } size_t size() const override; private: - std::unique_ptr encoding_; + std::string encoding_; bool end_headers_; }; @@ -920,7 +918,7 @@ class QUICHE_EXPORT_PRIVATE SpdyPriorityUpdateIR : public SpdyFrameIR { std::string priority_field_value_; }; -struct AcceptChOriginValuePair { +struct QUICHE_EXPORT_PRIVATE AcceptChOriginValuePair { std::string origin; std::string value; bool operator==(const AcceptChOriginValuePair& rhs) const { @@ -1061,9 +1059,6 @@ class QUICHE_EXPORT_PRIVATE SpdySerializedFrame { return buffer; } - // Returns the estimate of dynamically allocated memory in bytes. - size_t EstimateMemoryUsage() const { return owns_buffer_ ? size_ : 0; } - protected: char* frame_; diff --git a/net/base/io_buffer.cc b/net/base/io_buffer.cc index e7091589..75a45fac 100644 --- a/net/base/io_buffer.cc +++ b/net/base/io_buffer.cc @@ -1,78 +1,78 @@ -//copy from chromium/net/base/io_buffer.cc - -#include "net/base/io_buffer.h" - -#include "gquiche/common/platform/api/quiche_logging.h" - -using namespace quic; - -namespace net { - -// TODO(eroman): IOBuffer is being converted to require buffer sizes and offsets -// be specified as "size_t" rather than "int" (crbug.com/488553). To facilitate -// this move (since LOTS of code needs to be updated), both "size_t" and "int -// are being accepted. When using "size_t" this function ensures that it can be -// safely converted to an "int" without truncation. - -IOBuffer::IOBuffer() : data_(nullptr) {} - -IOBuffer::IOBuffer(size_t buffer_size) { - data_ = new char[buffer_size]; -} - -IOBuffer::IOBuffer(char* data) - : data_(data) { -} - -IOBuffer::~IOBuffer() { - delete[] data_; - data_ = nullptr; -} - -IOBufferWithSize::IOBufferWithSize(size_t size) : IOBuffer(size), size_(size) { - // Note: Size check is done in superclass' constructor. -} - -IOBufferWithSize::IOBufferWithSize(char* data, size_t size) - : IOBuffer(data), size_(size) { -} - -IOBufferWithSize::~IOBufferWithSize() = default; - - -DrainableIOBuffer::DrainableIOBuffer( - QuicReferenceCountedPointer base, int size) - : IOBuffer(base->data()), base_(std::move(base)), size_(size), used_(0) { -} - -DrainableIOBuffer::DrainableIOBuffer( - QuicReferenceCountedPointer base, size_t size) - : IOBuffer(base->data()), base_(std::move(base)), size_(size), used_(0) { -} - -void DrainableIOBuffer::DidConsume(int bytes) { - SetOffset(used_ + bytes); -} - -int DrainableIOBuffer::BytesRemaining() const { - return size_ - used_; -} - -// Returns the number of consumed bytes. -int DrainableIOBuffer::BytesConsumed() const { - return used_; -} - -void DrainableIOBuffer::SetOffset(int bytes) { - QUICHE_DCHECK_GE(bytes, 0); - QUICHE_DCHECK_LE(bytes, size_); - used_ = bytes; - data_ = base_->data() + used_; -} - -DrainableIOBuffer::~DrainableIOBuffer() { - // The buffer is owned by the |base_| instance. - data_ = nullptr; -} - -} // namespace net +//copy from chromium/net/base/io_buffer.cc + +#include "net/base/io_buffer.h" + +#include "gquiche/common/platform/api/quiche_logging.h" + +using namespace quic; + +namespace net { + +// TODO(eroman): IOBuffer is being converted to require buffer sizes and offsets +// be specified as "size_t" rather than "int" (crbug.com/488553). To facilitate +// this move (since LOTS of code needs to be updated), both "size_t" and "int +// are being accepted. When using "size_t" this function ensures that it can be +// safely converted to an "int" without truncation. + +IOBuffer::IOBuffer() : data_(nullptr) {} + +IOBuffer::IOBuffer(size_t buffer_size) { + data_ = new char[buffer_size]; +} + +IOBuffer::IOBuffer(char* data) + : data_(data) { +} + +IOBuffer::~IOBuffer() { + delete[] data_; + data_ = nullptr; +} + +IOBufferWithSize::IOBufferWithSize(size_t size) : IOBuffer(size), size_(size) { + // Note: Size check is done in superclass' constructor. +} + +IOBufferWithSize::IOBufferWithSize(char* data, size_t size) + : IOBuffer(data), size_(size) { +} + +IOBufferWithSize::~IOBufferWithSize() = default; + + +DrainableIOBuffer::DrainableIOBuffer( + QuicReferenceCountedPointer base, int size) + : IOBuffer(base->data()), base_(std::move(base)), size_(size), used_(0) { +} + +DrainableIOBuffer::DrainableIOBuffer( + QuicReferenceCountedPointer base, size_t size) + : IOBuffer(base->data()), base_(std::move(base)), size_(size), used_(0) { +} + +void DrainableIOBuffer::DidConsume(int bytes) { + SetOffset(used_ + bytes); +} + +int DrainableIOBuffer::BytesRemaining() const { + return size_ - used_; +} + +// Returns the number of consumed bytes. +int DrainableIOBuffer::BytesConsumed() const { + return used_; +} + +void DrainableIOBuffer::SetOffset(int bytes) { + QUICHE_DCHECK_GE(bytes, 0); + QUICHE_DCHECK_LE(bytes, size_); + used_ = bytes; + data_ = base_->data() + used_; +} + +DrainableIOBuffer::~DrainableIOBuffer() { + // The buffer is owned by the |base_| instance. + data_ = nullptr; +} + +} // namespace net diff --git a/net/base/io_buffer.h b/net/base/io_buffer.h index c7a906bb..2adc6af2 100644 --- a/net/base/io_buffer.h +++ b/net/base/io_buffer.h @@ -1,104 +1,104 @@ -// io_buffer is derived from chromium/net/base/io_buffer - -#ifndef QUICHE_NET_BASE_IO_BUFFER_H_ -#define QUICHE_NET_BASE_IO_BUFFER_H_ - -#include - -#include -#include - -#include "gquiche/quic/platform/api/quic_reference_counted.h" - -namespace net{ - -// This is a simplified version of chromium net::IOBuffer. That is to say, We -//take of reference count of that. -class IOBuffer : public quic::QuicReferenceCounted { - public: - IOBuffer(); - - explicit IOBuffer(size_t buffer_size); - - char* data() const { return data_; } - - protected: - - // Only allow derived classes to specify data_. - // In all other cases, we own data_, and must delete it at destruction time. - explicit IOBuffer(char* data); - - virtual ~IOBuffer(); - - char* data_; -}; - -// This version stores the size of the buffer so that the creator of the object -// doesn't have to keep track of that value. -// NOTE: This doesn't mean that we want to stop sending the size as an explicit -// argument to IO functions. Please keep using IOBuffer* for API declarations. -class IOBufferWithSize : public IOBuffer { - public: - explicit IOBufferWithSize(size_t size); - - int size() const { return size_; } - - protected: - // Purpose of this constructor is to give a subclass access to the base class - // constructor IOBuffer(char*) thus allowing subclass to use underlying - // memory it does not own. - IOBufferWithSize(char* data, size_t size); - ~IOBufferWithSize() override; - - int size_; -}; - -// This version wraps an existing IOBuffer and provides convenient functions -// to progressively read all the data. -// -// DrainableIOBuffer is useful when you have an IOBuffer that contains data -// to be written progressively, and Write() function takes an IOBuffer rather -// than char*. DrainableIOBuffer can be used as follows: -// -// // payload is the IOBuffer containing the data to be written. -// buf = base::MakeRefCounted(payload, payload_size); -// -// while (buf->BytesRemaining() > 0) { -// // Write() takes an IOBuffer. If it takes char*, we could -// // simply use the regular IOBuffer like payload->data() + offset. -// int bytes_written = Write(buf, buf->BytesRemaining()); -// buf->DidConsume(bytes_written); -// } -// -class DrainableIOBuffer : public IOBuffer { - public: - // TODO(eroman): Deprecated. Use the size_t flavor instead. crbug.com/488553 - DrainableIOBuffer(quic::QuicReferenceCountedPointer base, int size); - DrainableIOBuffer(quic::QuicReferenceCountedPointer base, size_t size); - - // DidConsume() changes the |data_| pointer so that |data_| always points - // to the first unconsumed byte. - void DidConsume(int bytes); - - // Returns the number of unconsumed bytes. - int BytesRemaining() const; - - // Returns the number of consumed bytes. - int BytesConsumed() const; - - // Seeks to an arbitrary point in the buffer. The notion of bytes consumed - // and remaining are updated appropriately. - void SetOffset(int bytes); - - int size() const { return size_; } - - private: - ~DrainableIOBuffer() override; - - quic::QuicReferenceCountedPointer base_; - int size_; - int used_; -}; - -} // namespace net -#endif // QUICHE_NET_BASE_IO_BUFFER_H_ +// io_buffer is derived from chromium/net/base/io_buffer + +#ifndef QUICHE_NET_BASE_IO_BUFFER_H_ +#define QUICHE_NET_BASE_IO_BUFFER_H_ + +#include + +#include +#include + +#include "gquiche/quic/platform/api/quic_reference_counted.h" + +namespace net{ + +// This is a simplified version of chromium net::IOBuffer. That is to say, We +//take of reference count of that. +class IOBuffer : public quic::QuicReferenceCounted { + public: + IOBuffer(); + + explicit IOBuffer(size_t buffer_size); + + char* data() const { return data_; } + + protected: + + // Only allow derived classes to specify data_. + // In all other cases, we own data_, and must delete it at destruction time. + explicit IOBuffer(char* data); + + virtual ~IOBuffer(); + + char* data_; +}; + +// This version stores the size of the buffer so that the creator of the object +// doesn't have to keep track of that value. +// NOTE: This doesn't mean that we want to stop sending the size as an explicit +// argument to IO functions. Please keep using IOBuffer* for API declarations. +class IOBufferWithSize : public IOBuffer { + public: + explicit IOBufferWithSize(size_t size); + + int size() const { return size_; } + + protected: + // Purpose of this constructor is to give a subclass access to the base class + // constructor IOBuffer(char*) thus allowing subclass to use underlying + // memory it does not own. + IOBufferWithSize(char* data, size_t size); + ~IOBufferWithSize() override; + + int size_; +}; + +// This version wraps an existing IOBuffer and provides convenient functions +// to progressively read all the data. +// +// DrainableIOBuffer is useful when you have an IOBuffer that contains data +// to be written progressively, and Write() function takes an IOBuffer rather +// than char*. DrainableIOBuffer can be used as follows: +// +// // payload is the IOBuffer containing the data to be written. +// buf = base::MakeRefCounted(payload, payload_size); +// +// while (buf->BytesRemaining() > 0) { +// // Write() takes an IOBuffer. If it takes char*, we could +// // simply use the regular IOBuffer like payload->data() + offset. +// int bytes_written = Write(buf, buf->BytesRemaining()); +// buf->DidConsume(bytes_written); +// } +// +class DrainableIOBuffer : public IOBuffer { + public: + // TODO(eroman): Deprecated. Use the size_t flavor instead. crbug.com/488553 + DrainableIOBuffer(quic::QuicReferenceCountedPointer base, int size); + DrainableIOBuffer(quic::QuicReferenceCountedPointer base, size_t size); + + // DidConsume() changes the |data_| pointer so that |data_| always points + // to the first unconsumed byte. + void DidConsume(int bytes); + + // Returns the number of unconsumed bytes. + int BytesRemaining() const; + + // Returns the number of consumed bytes. + int BytesConsumed() const; + + // Seeks to an arbitrary point in the buffer. The notion of bytes consumed + // and remaining are updated appropriately. + void SetOffset(int bytes); + + int size() const { return size_; } + + private: + ~DrainableIOBuffer() override; + + quic::QuicReferenceCountedPointer base_; + int size_; + int used_; +}; + +} // namespace net +#endif // QUICHE_NET_BASE_IO_BUFFER_H_ diff --git a/net/http/http_request_headers.cc b/net/http/http_request_headers.cc index e4457657..62d956e0 100644 --- a/net/http/http_request_headers.cc +++ b/net/http/http_request_headers.cc @@ -1,219 +1,219 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -// Copy from chromium/net/http_request_headers, and make some changes to compatible with quic_server_src -#include - -#include "base/strings/stringprintf.h" -#include "gquiche/quic/platform/api/quic_logging.h" -#include "googleurl/base/strings/string_split.h" -#include "googleurl/base/strings/string_util.h" -#include "net/http/http_util.h" -#include "net/http/http_request_headers.h" - -namespace net { - -const char HttpRequestHeaders::kConnectMethod[] = "CONNECT"; -const char HttpRequestHeaders::kGetMethod[] = "GET"; -const char HttpRequestHeaders::kHeadMethod[] = "HEAD"; -const char HttpRequestHeaders::kOptionsMethod[] = "OPTIONS"; -const char HttpRequestHeaders::kPostMethod[] = "POST"; -const char HttpRequestHeaders::kTraceMethod[] = "TRACE"; -const char HttpRequestHeaders::kTrackMethod[] = "TRACK"; -const char HttpRequestHeaders::kAccept[] = "Accept"; -const char HttpRequestHeaders::kAcceptCharset[] = "Accept-Charset"; -const char HttpRequestHeaders::kAcceptEncoding[] = "Accept-Encoding"; -const char HttpRequestHeaders::kAcceptLanguage[] = "Accept-Language"; -const char HttpRequestHeaders::kAuthorization[] = "Authorization"; -const char HttpRequestHeaders::kCacheControl[] = "Cache-Control"; -const char HttpRequestHeaders::kConnection[] = "Connection"; -const char HttpRequestHeaders::kContentLength[] = "Content-Length"; -const char HttpRequestHeaders::kContentType[] = "Content-Type"; -const char HttpRequestHeaders::kCookie[] = "Cookie"; -const char HttpRequestHeaders::kHost[] = "Host"; -const char HttpRequestHeaders::kIfMatch[] = "If-Match"; -const char HttpRequestHeaders::kIfModifiedSince[] = "If-Modified-Since"; -const char HttpRequestHeaders::kIfNoneMatch[] = "If-None-Match"; -const char HttpRequestHeaders::kIfRange[] = "If-Range"; -const char HttpRequestHeaders::kIfUnmodifiedSince[] = "If-Unmodified-Since"; -const char HttpRequestHeaders::kOrigin[] = "Origin"; -const char HttpRequestHeaders::kPragma[] = "Pragma"; -const char HttpRequestHeaders::kProxyAuthorization[] = "Proxy-Authorization"; -const char HttpRequestHeaders::kProxyConnection[] = "Proxy-Connection"; -const char HttpRequestHeaders::kRange[] = "Range"; -const char HttpRequestHeaders::kReferer[] = "Referer"; -const char HttpRequestHeaders::kTransferEncoding[] = "Transfer-Encoding"; -const char HttpRequestHeaders::kUserAgent[] = "User-Agent"; - -HttpRequestHeaders::HeaderKeyValuePair::HeaderKeyValuePair() = default; - -HttpRequestHeaders::HeaderKeyValuePair::HeaderKeyValuePair( - const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value) - : key(key.data(), key.size()), value(value.data(), value.size()) {} - -HttpRequestHeaders::Iterator::Iterator(const HttpRequestHeaders& headers) - : started_(false), - curr_(headers.headers_.begin()), - end_(headers.headers_.end()) {} - -HttpRequestHeaders::Iterator::~Iterator() = default; - -bool HttpRequestHeaders::Iterator::GetNext() { - if (!started_) { - started_ = true; - return curr_ != end_; - } - - if (curr_ == end_) - return false; - - ++curr_; - return curr_ != end_; -} - -HttpRequestHeaders::HttpRequestHeaders() = default; -HttpRequestHeaders::HttpRequestHeaders(const HttpRequestHeaders& other) = - default; -HttpRequestHeaders::HttpRequestHeaders(HttpRequestHeaders&& other) = default; -HttpRequestHeaders::~HttpRequestHeaders() = default; - -HttpRequestHeaders& HttpRequestHeaders::operator=( - const HttpRequestHeaders& other) = default; -HttpRequestHeaders& HttpRequestHeaders::operator=(HttpRequestHeaders&& other) = - default; - -bool HttpRequestHeaders::GetHeader(const gurl_base::StringPiece& key, - std::string* out) const { - auto it = FindHeader(key); - if (it == headers_.end()) - return false; - out->assign(it->value); - return true; -} - -void HttpRequestHeaders::Clear() { - headers_.clear(); -} - -void HttpRequestHeaders::SetHeader(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value) { - // Invalid header names or values could mean clients can attach - // browser-internal headers. - QUICHE_DCHECK(HttpUtil::IsValidHeaderName(key)) << key; - QUICHE_DCHECK(HttpUtil::IsValidHeaderValue(value)) << key << ":" << value; - SetHeaderInternal(key, value); -} - -void HttpRequestHeaders::SetHeaderIfMissing(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value) { - // Invalid header names or values could mean clients can attach - // browser-internal headers. - QUICHE_DCHECK(HttpUtil::IsValidHeaderName(key)); - QUICHE_DCHECK(HttpUtil::IsValidHeaderValue(value)); - auto it = FindHeader(key); - if (it == headers_.end()) - headers_.push_back(HeaderKeyValuePair(key, value)); -} - -void HttpRequestHeaders::RemoveHeader(const gurl_base::StringPiece& key) { - auto it = FindHeader(key); - if (it != headers_.end()) - headers_.erase(it); -} - -void HttpRequestHeaders::AddHeaderFromString( - const gurl_base::StringPiece& header_line) { - QUICHE_DCHECK_EQ(std::string::npos, header_line.find("\r\n")) - << "\"" << header_line << "\" contains CRLF."; - - const std::string::size_type key_end_index = header_line.find(":"); - if (key_end_index == std::string::npos) { - QUIC_LOG(DFATAL) << "\"" << header_line << "\" is missing colon delimiter."; - return; - } - - if (key_end_index == 0) { - QUIC_LOG(DFATAL) << "\"" << header_line << "\" is missing header key."; - return; - } - - const gurl_base::StringPiece header_key(header_line.data(), key_end_index); - if (!HttpUtil::IsValidHeaderName(header_key)) { - QUIC_LOG(DFATAL) << "\"" << header_line << "\" has invalid header key."; - return; - } - - const std::string::size_type value_index = key_end_index + 1; - - if (value_index < header_line.size()) { - gurl_base::StringPiece header_value(header_line.data() + value_index, - header_line.size() - value_index); - header_value = HttpUtil::TrimLWS(header_value); - if (!HttpUtil::IsValidHeaderValue(header_value)) { - QUIC_LOG(DFATAL) << "\"" << header_line << "\" has invalid header value."; - return; - } - SetHeader(header_key, header_value); - } else if (value_index == header_line.size()) { - SetHeader(header_key, ""); - } else { - QUIC_NOTREACHED(); - } -} - -void HttpRequestHeaders::AddHeadersFromString( - const gurl_base::StringPiece& headers) { - for (const gurl_base::StringPiece& header : gurl_base::SplitStringPieceUsingSubstr( - headers, "\r\n", gurl_base::TRIM_WHITESPACE, gurl_base::SPLIT_WANT_NONEMPTY)) { - AddHeaderFromString(header); - } -} - -void HttpRequestHeaders::MergeFrom(const HttpRequestHeaders& other) { - for (auto it = other.headers_.begin(); it != other.headers_.end(); ++it) { - SetHeader(it->key, it->value); - } -} - -std::string HttpRequestHeaders::ToString() const { - std::string output; - for (auto it = headers_.begin(); it != headers_.end(); ++it) { - base::StringAppendF(&output, "%s: %s\r\n", it->key.c_str(), - it->value.c_str()); - } - output.append("\r\n"); - return output; -} - -HttpRequestHeaders::HeaderVector::iterator HttpRequestHeaders::FindHeader( - const gurl_base::StringPiece& key) { - for (auto it = headers_.begin(); it != headers_.end(); ++it) { - if (gurl_base::EqualsCaseInsensitiveASCII(key, it->key)) - return it; - } - - return headers_.end(); -} - -HttpRequestHeaders::HeaderVector::const_iterator HttpRequestHeaders::FindHeader( - const gurl_base::StringPiece& key) const { - for (auto it = headers_.begin(); it != headers_.end(); ++it) { - if (gurl_base::EqualsCaseInsensitiveASCII(key, it->key)) - return it; - } - - return headers_.end(); -} - -void HttpRequestHeaders::SetHeaderInternal(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value) { - auto it = FindHeader(key); - if (it != headers_.end()) - it->value.assign(value.data(), value.size()); - else - headers_.push_back(HeaderKeyValuePair(key, value)); -} - -} // namespace net +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Copy from chromium/net/http_request_headers, and make some changes to compatible with quic_server_src +#include + +#include "base/strings/stringprintf.h" +#include "gquiche/quic/platform/api/quic_logging.h" +#include "googleurl/base/strings/string_split.h" +#include "googleurl/base/strings/string_util.h" +#include "net/http/http_util.h" +#include "net/http/http_request_headers.h" + +namespace net { + +const char HttpRequestHeaders::kConnectMethod[] = "CONNECT"; +const char HttpRequestHeaders::kGetMethod[] = "GET"; +const char HttpRequestHeaders::kHeadMethod[] = "HEAD"; +const char HttpRequestHeaders::kOptionsMethod[] = "OPTIONS"; +const char HttpRequestHeaders::kPostMethod[] = "POST"; +const char HttpRequestHeaders::kTraceMethod[] = "TRACE"; +const char HttpRequestHeaders::kTrackMethod[] = "TRACK"; +const char HttpRequestHeaders::kAccept[] = "Accept"; +const char HttpRequestHeaders::kAcceptCharset[] = "Accept-Charset"; +const char HttpRequestHeaders::kAcceptEncoding[] = "Accept-Encoding"; +const char HttpRequestHeaders::kAcceptLanguage[] = "Accept-Language"; +const char HttpRequestHeaders::kAuthorization[] = "Authorization"; +const char HttpRequestHeaders::kCacheControl[] = "Cache-Control"; +const char HttpRequestHeaders::kConnection[] = "Connection"; +const char HttpRequestHeaders::kContentLength[] = "Content-Length"; +const char HttpRequestHeaders::kContentType[] = "Content-Type"; +const char HttpRequestHeaders::kCookie[] = "Cookie"; +const char HttpRequestHeaders::kHost[] = "Host"; +const char HttpRequestHeaders::kIfMatch[] = "If-Match"; +const char HttpRequestHeaders::kIfModifiedSince[] = "If-Modified-Since"; +const char HttpRequestHeaders::kIfNoneMatch[] = "If-None-Match"; +const char HttpRequestHeaders::kIfRange[] = "If-Range"; +const char HttpRequestHeaders::kIfUnmodifiedSince[] = "If-Unmodified-Since"; +const char HttpRequestHeaders::kOrigin[] = "Origin"; +const char HttpRequestHeaders::kPragma[] = "Pragma"; +const char HttpRequestHeaders::kProxyAuthorization[] = "Proxy-Authorization"; +const char HttpRequestHeaders::kProxyConnection[] = "Proxy-Connection"; +const char HttpRequestHeaders::kRange[] = "Range"; +const char HttpRequestHeaders::kReferer[] = "Referer"; +const char HttpRequestHeaders::kTransferEncoding[] = "Transfer-Encoding"; +const char HttpRequestHeaders::kUserAgent[] = "User-Agent"; + +HttpRequestHeaders::HeaderKeyValuePair::HeaderKeyValuePair() = default; + +HttpRequestHeaders::HeaderKeyValuePair::HeaderKeyValuePair( + const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value) + : key(key.data(), key.size()), value(value.data(), value.size()) {} + +HttpRequestHeaders::Iterator::Iterator(const HttpRequestHeaders& headers) + : started_(false), + curr_(headers.headers_.begin()), + end_(headers.headers_.end()) {} + +HttpRequestHeaders::Iterator::~Iterator() = default; + +bool HttpRequestHeaders::Iterator::GetNext() { + if (!started_) { + started_ = true; + return curr_ != end_; + } + + if (curr_ == end_) + return false; + + ++curr_; + return curr_ != end_; +} + +HttpRequestHeaders::HttpRequestHeaders() = default; +HttpRequestHeaders::HttpRequestHeaders(const HttpRequestHeaders& other) = + default; +HttpRequestHeaders::HttpRequestHeaders(HttpRequestHeaders&& other) = default; +HttpRequestHeaders::~HttpRequestHeaders() = default; + +HttpRequestHeaders& HttpRequestHeaders::operator=( + const HttpRequestHeaders& other) = default; +HttpRequestHeaders& HttpRequestHeaders::operator=(HttpRequestHeaders&& other) = + default; + +bool HttpRequestHeaders::GetHeader(const gurl_base::StringPiece& key, + std::string* out) const { + auto it = FindHeader(key); + if (it == headers_.end()) + return false; + out->assign(it->value); + return true; +} + +void HttpRequestHeaders::Clear() { + headers_.clear(); +} + +void HttpRequestHeaders::SetHeader(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value) { + // Invalid header names or values could mean clients can attach + // browser-internal headers. + QUICHE_DCHECK(HttpUtil::IsValidHeaderName(key)) << key; + QUICHE_DCHECK(HttpUtil::IsValidHeaderValue(value)) << key << ":" << value; + SetHeaderInternal(key, value); +} + +void HttpRequestHeaders::SetHeaderIfMissing(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value) { + // Invalid header names or values could mean clients can attach + // browser-internal headers. + QUICHE_DCHECK(HttpUtil::IsValidHeaderName(key)); + QUICHE_DCHECK(HttpUtil::IsValidHeaderValue(value)); + auto it = FindHeader(key); + if (it == headers_.end()) + headers_.push_back(HeaderKeyValuePair(key, value)); +} + +void HttpRequestHeaders::RemoveHeader(const gurl_base::StringPiece& key) { + auto it = FindHeader(key); + if (it != headers_.end()) + headers_.erase(it); +} + +void HttpRequestHeaders::AddHeaderFromString( + const gurl_base::StringPiece& header_line) { + QUICHE_DCHECK_EQ(std::string::npos, header_line.find("\r\n")) + << "\"" << header_line << "\" contains CRLF."; + + const std::string::size_type key_end_index = header_line.find(":"); + if (key_end_index == std::string::npos) { + QUIC_LOG(DFATAL) << "\"" << header_line << "\" is missing colon delimiter."; + return; + } + + if (key_end_index == 0) { + QUIC_LOG(DFATAL) << "\"" << header_line << "\" is missing header key."; + return; + } + + const gurl_base::StringPiece header_key(header_line.data(), key_end_index); + if (!HttpUtil::IsValidHeaderName(header_key)) { + QUIC_LOG(DFATAL) << "\"" << header_line << "\" has invalid header key."; + return; + } + + const std::string::size_type value_index = key_end_index + 1; + + if (value_index < header_line.size()) { + gurl_base::StringPiece header_value(header_line.data() + value_index, + header_line.size() - value_index); + header_value = HttpUtil::TrimLWS(header_value); + if (!HttpUtil::IsValidHeaderValue(header_value)) { + QUIC_LOG(DFATAL) << "\"" << header_line << "\" has invalid header value."; + return; + } + SetHeader(header_key, header_value); + } else if (value_index == header_line.size()) { + SetHeader(header_key, ""); + } else { + QUIC_NOTREACHED(); + } +} + +void HttpRequestHeaders::AddHeadersFromString( + const gurl_base::StringPiece& headers) { + for (const gurl_base::StringPiece& header : gurl_base::SplitStringPieceUsingSubstr( + headers, "\r\n", gurl_base::TRIM_WHITESPACE, gurl_base::SPLIT_WANT_NONEMPTY)) { + AddHeaderFromString(header); + } +} + +void HttpRequestHeaders::MergeFrom(const HttpRequestHeaders& other) { + for (auto it = other.headers_.begin(); it != other.headers_.end(); ++it) { + SetHeader(it->key, it->value); + } +} + +std::string HttpRequestHeaders::ToString() const { + std::string output; + for (auto it = headers_.begin(); it != headers_.end(); ++it) { + base::StringAppendF(&output, "%s: %s\r\n", it->key.c_str(), + it->value.c_str()); + } + output.append("\r\n"); + return output; +} + +HttpRequestHeaders::HeaderVector::iterator HttpRequestHeaders::FindHeader( + const gurl_base::StringPiece& key) { + for (auto it = headers_.begin(); it != headers_.end(); ++it) { + if (gurl_base::EqualsCaseInsensitiveASCII(key, it->key)) + return it; + } + + return headers_.end(); +} + +HttpRequestHeaders::HeaderVector::const_iterator HttpRequestHeaders::FindHeader( + const gurl_base::StringPiece& key) const { + for (auto it = headers_.begin(); it != headers_.end(); ++it) { + if (gurl_base::EqualsCaseInsensitiveASCII(key, it->key)) + return it; + } + + return headers_.end(); +} + +void HttpRequestHeaders::SetHeaderInternal(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value) { + auto it = FindHeader(key); + if (it != headers_.end()) + it->value.assign(value.data(), value.size()); + else + headers_.push_back(HeaderKeyValuePair(key, value)); +} + +} // namespace net diff --git a/net/http/http_request_headers.h b/net/http/http_request_headers.h index 297b07b0..016e7e4d 100644 --- a/net/http/http_request_headers.h +++ b/net/http/http_request_headers.h @@ -1,197 +1,197 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. -// Copy from chromium/net/http_request_headers, and make some changes to compatible with quic_server_src: -/* - 1. remove deps on net/base/net_export.h - 2. remove unsed net/log/net_log_capture_mode.h - 3. switch StringPiece to Quiche::QuicheStringPiece -*/ -// HttpRequestHeaders manages the request headers. -// It maintains these in a vector of header key/value pairs, thereby maintaining -// the order of the headers. This means that any lookups are linear time -// operations. - -#ifndef QUICHE_NET_HTTP_HTTP_REQUEST_HEADERS_H_ -#define QUICHE_NET_HTTP_HTTP_REQUEST_HEADERS_H_ - -#include -#include -#include - -#include "googleurl/base/macros.h" -#include "googleurl/base/strings/string_piece.h" - -namespace net { - -class HttpRequestHeaders { - public: - struct HeaderKeyValuePair { - HeaderKeyValuePair(); - HeaderKeyValuePair(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value); - - std::string key; - std::string value; - }; - - typedef std::vector HeaderVector; - - class Iterator { - public: - explicit Iterator(const HttpRequestHeaders& headers); - ~Iterator(); - - // Advances the iterator to the next header, if any. Returns true if there - // is a next header. Use name() and value() methods to access the resultant - // header name and value. - bool GetNext(); - - // These two accessors are only valid if GetNext() returned true. - const std::string& name() const { return curr_->key; } - const std::string& value() const { return curr_->value; } - - private: - bool started_; - HttpRequestHeaders::HeaderVector::const_iterator curr_; - const HttpRequestHeaders::HeaderVector::const_iterator end_; - - DISALLOW_COPY_AND_ASSIGN(Iterator); - }; - - static const char kConnectMethod[]; - static const char kGetMethod[]; - static const char kHeadMethod[]; - static const char kOptionsMethod[]; - static const char kPostMethod[]; - static const char kTraceMethod[]; - static const char kTrackMethod[]; - - static const char kAccept[]; - static const char kAcceptCharset[]; - static const char kAcceptEncoding[]; - static const char kAcceptLanguage[]; - static const char kAuthorization[]; - static const char kCacheControl[]; - static const char kConnection[]; - static const char kContentType[]; - static const char kCookie[]; - static const char kContentLength[]; - static const char kHost[]; - static const char kIfMatch[]; - static const char kIfModifiedSince[]; - static const char kIfNoneMatch[]; - static const char kIfRange[]; - static const char kIfUnmodifiedSince[]; - static const char kOrigin[]; - static const char kPragma[]; - static const char kProxyAuthorization[]; - static const char kProxyConnection[]; - static const char kRange[]; - static const char kReferer[]; - static const char kTransferEncoding[]; - static const char kUserAgent[]; - - HttpRequestHeaders(); - HttpRequestHeaders(const HttpRequestHeaders& other); - HttpRequestHeaders(HttpRequestHeaders&& other); - ~HttpRequestHeaders(); - - HttpRequestHeaders& operator=(const HttpRequestHeaders& other); - HttpRequestHeaders& operator=(HttpRequestHeaders&& other); - - bool IsEmpty() const { return headers_.empty(); } - - bool HasHeader(const gurl_base::StringPiece& key) const { - return FindHeader(key) != headers_.end(); - } - - // Gets the first header that matches |key|. If found, returns true and - // writes the value to |out|. - bool GetHeader(const gurl_base::StringPiece& key, std::string* out) const; - - // Clears all the headers. - void Clear(); - - // Sets the header value pair for |key| and |value|. If |key| already exists, - // then the header value is modified, but the key is untouched, and the order - // in the vector remains the same. When comparing |key|, case is ignored. - // The caller must ensure that |key| passes HttpUtil::IsValidHeaderName() and - // |value| passes HttpUtil::IsValidHeaderValue(). - void SetHeader(const gurl_base::StringPiece& key, const gurl_base::StringPiece& value); - - // Does the same as above but without internal DCHECKs for validations. - void SetHeaderWithoutCheckForTesting(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value) { - SetHeaderInternal(key, value); - } - - // Sets the header value pair for |key| and |value|, if |key| does not exist. - // If |key| already exists, the call is a no-op. - // When comparing |key|, case is ignored. - // - // The caller must ensure that |key| passes HttpUtil::IsValidHeaderName() and - // |value| passes HttpUtil::IsValidHeaderValue(). - void SetHeaderIfMissing(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value); - - // Removes the first header that matches (case insensitive) |key|. - void RemoveHeader(const gurl_base::StringPiece& key); - - // Parses the header from a string and calls SetHeader() with it. This string - // should not contain any CRLF. As per RFC7230 Section 3.2, the format is: - // - // header-field = field-name ":" OWS field-value OWS - // - // field-name = token - // field-value = *( field-content / obs-fold ) - // field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] - // field-vchar = VCHAR / obs-text - // - // obs-fold = CRLF 1*( SP / HTAB ) - // ; obsolete line folding - // ; see Section 3.2.4 - // - // AddHeaderFromString() will trim any LWS surrounding the - // field-content. - void AddHeaderFromString(const gurl_base::StringPiece& header_line); - - // Same thing as AddHeaderFromString() except that |headers| is a "\r\n" - // delimited string of header lines. It will split up the string by "\r\n" - // and call AddHeaderFromString() on each. - void AddHeadersFromString(const gurl_base::StringPiece& headers); - - // Calls SetHeader() on each header from |other|, maintaining order. - void MergeFrom(const HttpRequestHeaders& other); - - // Copies from |other| to |this|. - void CopyFrom(const HttpRequestHeaders& other) { *this = other; } - - void Swap(HttpRequestHeaders* other) { headers_.swap(other->headers_); } - - // Serializes HttpRequestHeaders to a string representation. Joins all the - // header keys and values with ": ", and inserts "\r\n" between each header - // line, and adds the trailing "\r\n". - std::string ToString() const; - - const HeaderVector& GetHeaderVector() const { return headers_; } - - private: - HeaderVector::iterator FindHeader(const gurl_base::StringPiece& key); - HeaderVector::const_iterator FindHeader(const gurl_base::StringPiece& key) const; - - void SetHeaderInternal(const gurl_base::StringPiece& key, - const gurl_base::StringPiece& value); - - HeaderVector headers_; - - // Allow the copy construction and operator= to facilitate copying in - // HttpRequestHeaders. - // TODO(willchan): Investigate to see if we can remove the need to copy - // HttpRequestHeaders. - // DISALLOW_COPY_AND_ASSIGN(HttpRequestHeaders); -}; - -} // namespace net - -#endif // QUICHE_NET_HTTP_HTTP_REQUEST_HEADERS_H_ +// Copyright (c) 2012 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// Copy from chromium/net/http_request_headers, and make some changes to compatible with quic_server_src: +/* + 1. remove deps on net/base/net_export.h + 2. remove unsed net/log/net_log_capture_mode.h + 3. switch StringPiece to Quiche::QuicheStringPiece +*/ +// HttpRequestHeaders manages the request headers. +// It maintains these in a vector of header key/value pairs, thereby maintaining +// the order of the headers. This means that any lookups are linear time +// operations. + +#ifndef QUICHE_NET_HTTP_HTTP_REQUEST_HEADERS_H_ +#define QUICHE_NET_HTTP_HTTP_REQUEST_HEADERS_H_ + +#include +#include +#include + +#include "googleurl/base/macros.h" +#include "googleurl/base/strings/string_piece.h" + +namespace net { + +class HttpRequestHeaders { + public: + struct HeaderKeyValuePair { + HeaderKeyValuePair(); + HeaderKeyValuePair(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value); + + std::string key; + std::string value; + }; + + typedef std::vector HeaderVector; + + class Iterator { + public: + explicit Iterator(const HttpRequestHeaders& headers); + ~Iterator(); + + // Advances the iterator to the next header, if any. Returns true if there + // is a next header. Use name() and value() methods to access the resultant + // header name and value. + bool GetNext(); + + // These two accessors are only valid if GetNext() returned true. + const std::string& name() const { return curr_->key; } + const std::string& value() const { return curr_->value; } + + private: + bool started_; + HttpRequestHeaders::HeaderVector::const_iterator curr_; + const HttpRequestHeaders::HeaderVector::const_iterator end_; + + DISALLOW_COPY_AND_ASSIGN(Iterator); + }; + + static const char kConnectMethod[]; + static const char kGetMethod[]; + static const char kHeadMethod[]; + static const char kOptionsMethod[]; + static const char kPostMethod[]; + static const char kTraceMethod[]; + static const char kTrackMethod[]; + + static const char kAccept[]; + static const char kAcceptCharset[]; + static const char kAcceptEncoding[]; + static const char kAcceptLanguage[]; + static const char kAuthorization[]; + static const char kCacheControl[]; + static const char kConnection[]; + static const char kContentType[]; + static const char kCookie[]; + static const char kContentLength[]; + static const char kHost[]; + static const char kIfMatch[]; + static const char kIfModifiedSince[]; + static const char kIfNoneMatch[]; + static const char kIfRange[]; + static const char kIfUnmodifiedSince[]; + static const char kOrigin[]; + static const char kPragma[]; + static const char kProxyAuthorization[]; + static const char kProxyConnection[]; + static const char kRange[]; + static const char kReferer[]; + static const char kTransferEncoding[]; + static const char kUserAgent[]; + + HttpRequestHeaders(); + HttpRequestHeaders(const HttpRequestHeaders& other); + HttpRequestHeaders(HttpRequestHeaders&& other); + ~HttpRequestHeaders(); + + HttpRequestHeaders& operator=(const HttpRequestHeaders& other); + HttpRequestHeaders& operator=(HttpRequestHeaders&& other); + + bool IsEmpty() const { return headers_.empty(); } + + bool HasHeader(const gurl_base::StringPiece& key) const { + return FindHeader(key) != headers_.end(); + } + + // Gets the first header that matches |key|. If found, returns true and + // writes the value to |out|. + bool GetHeader(const gurl_base::StringPiece& key, std::string* out) const; + + // Clears all the headers. + void Clear(); + + // Sets the header value pair for |key| and |value|. If |key| already exists, + // then the header value is modified, but the key is untouched, and the order + // in the vector remains the same. When comparing |key|, case is ignored. + // The caller must ensure that |key| passes HttpUtil::IsValidHeaderName() and + // |value| passes HttpUtil::IsValidHeaderValue(). + void SetHeader(const gurl_base::StringPiece& key, const gurl_base::StringPiece& value); + + // Does the same as above but without internal DCHECKs for validations. + void SetHeaderWithoutCheckForTesting(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value) { + SetHeaderInternal(key, value); + } + + // Sets the header value pair for |key| and |value|, if |key| does not exist. + // If |key| already exists, the call is a no-op. + // When comparing |key|, case is ignored. + // + // The caller must ensure that |key| passes HttpUtil::IsValidHeaderName() and + // |value| passes HttpUtil::IsValidHeaderValue(). + void SetHeaderIfMissing(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value); + + // Removes the first header that matches (case insensitive) |key|. + void RemoveHeader(const gurl_base::StringPiece& key); + + // Parses the header from a string and calls SetHeader() with it. This string + // should not contain any CRLF. As per RFC7230 Section 3.2, the format is: + // + // header-field = field-name ":" OWS field-value OWS + // + // field-name = token + // field-value = *( field-content / obs-fold ) + // field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] + // field-vchar = VCHAR / obs-text + // + // obs-fold = CRLF 1*( SP / HTAB ) + // ; obsolete line folding + // ; see Section 3.2.4 + // + // AddHeaderFromString() will trim any LWS surrounding the + // field-content. + void AddHeaderFromString(const gurl_base::StringPiece& header_line); + + // Same thing as AddHeaderFromString() except that |headers| is a "\r\n" + // delimited string of header lines. It will split up the string by "\r\n" + // and call AddHeaderFromString() on each. + void AddHeadersFromString(const gurl_base::StringPiece& headers); + + // Calls SetHeader() on each header from |other|, maintaining order. + void MergeFrom(const HttpRequestHeaders& other); + + // Copies from |other| to |this|. + void CopyFrom(const HttpRequestHeaders& other) { *this = other; } + + void Swap(HttpRequestHeaders* other) { headers_.swap(other->headers_); } + + // Serializes HttpRequestHeaders to a string representation. Joins all the + // header keys and values with ": ", and inserts "\r\n" between each header + // line, and adds the trailing "\r\n". + std::string ToString() const; + + const HeaderVector& GetHeaderVector() const { return headers_; } + + private: + HeaderVector::iterator FindHeader(const gurl_base::StringPiece& key); + HeaderVector::const_iterator FindHeader(const gurl_base::StringPiece& key) const; + + void SetHeaderInternal(const gurl_base::StringPiece& key, + const gurl_base::StringPiece& value); + + HeaderVector headers_; + + // Allow the copy construction and operator= to facilitate copying in + // HttpRequestHeaders. + // TODO(willchan): Investigate to see if we can remove the need to copy + // HttpRequestHeaders. + // DISALLOW_COPY_AND_ASSIGN(HttpRequestHeaders); +}; + +} // namespace net + +#endif // QUICHE_NET_HTTP_HTTP_REQUEST_HEADERS_H_ diff --git a/net/http/http_util.cc b/net/http/http_util.cc index 65c1e64a..4a790da2 100644 --- a/net/http/http_util.cc +++ b/net/http/http_util.cc @@ -1,75 +1,75 @@ -// http_util is derived from chromium net/http/http_util -#include "net/http/http_util.h" - -namespace net { - -namespace { - -template -void TrimLWSImplementation(ConstIterator* begin, ConstIterator* end) { - // leading whitespace - while (*begin < *end && HttpUtil::IsLWS((*begin)[0])) - ++(*begin); - - // trailing whitespace - while (*begin < *end && HttpUtil::IsLWS((*end)[-1])) - --(*end); -} - -} //namespace - - -// static -bool HttpUtil::IsValidHeaderName(gurl_base::StringPiece name) { - // Check whether the header name is RFC 2616-compliant. - return HttpUtil::IsToken(name); -} - -// static -bool HttpUtil::IsValidHeaderValue(gurl_base::StringPiece value) { - // Just a sanity check: disallow NUL, CR and LF. - for (char c : value) { - if (c == '\0' || c == '\r' || c == '\n') - return false; - } - return true; -} - -bool HttpUtil::IsLWS(char c) { - const gurl_base::StringPiece kWhiteSpaceCharacters(HTTP_LWS); - return kWhiteSpaceCharacters.find(c) != gurl_base::StringPiece::npos; -} - -// static -void HttpUtil::TrimLWS(std::string::const_iterator* begin, - std::string::const_iterator* end) { - TrimLWSImplementation(begin, end); -} - -// static -gurl_base::StringPiece HttpUtil::TrimLWS(const gurl_base::StringPiece& string) { - const char* begin = string.data(); - const char* end = string.data() + string.size(); - TrimLWSImplementation(&begin, &end); - return gurl_base::StringPiece(begin, end - begin); -} - -// See RFC 7230 Sec 3.2.6 for the definition of |token|. -bool HttpUtil::IsToken(gurl_base::StringPiece string) { - if (string.empty()) - return false; - for (char c : string) { - if (!IsTokenChar(c)) - return false; - } - return true; -} - -bool HttpUtil::IsTokenChar(char c) { - return !(c >= 0x7F || c <= 0x20 || c == '(' || c == ')' || c == '<' || - c == '>' || c == '@' || c == ',' || c == ';' || c == ':' || - c == '\\' || c == '"' || c == '/' || c == '[' || c == ']' || - c == '?' || c == '=' || c == '{' || c == '}'); -} - -} // namespace net +// http_util is derived from chromium net/http/http_util +#include "net/http/http_util.h" + +namespace net { + +namespace { + +template +void TrimLWSImplementation(ConstIterator* begin, ConstIterator* end) { + // leading whitespace + while (*begin < *end && HttpUtil::IsLWS((*begin)[0])) + ++(*begin); + + // trailing whitespace + while (*begin < *end && HttpUtil::IsLWS((*end)[-1])) + --(*end); +} + +} //namespace + + +// static +bool HttpUtil::IsValidHeaderName(gurl_base::StringPiece name) { + // Check whether the header name is RFC 2616-compliant. + return HttpUtil::IsToken(name); +} + +// static +bool HttpUtil::IsValidHeaderValue(gurl_base::StringPiece value) { + // Just a sanity check: disallow NUL, CR and LF. + for (char c : value) { + if (c == '\0' || c == '\r' || c == '\n') + return false; + } + return true; +} + +bool HttpUtil::IsLWS(char c) { + const gurl_base::StringPiece kWhiteSpaceCharacters(HTTP_LWS); + return kWhiteSpaceCharacters.find(c) != gurl_base::StringPiece::npos; +} + +// static +void HttpUtil::TrimLWS(std::string::const_iterator* begin, + std::string::const_iterator* end) { + TrimLWSImplementation(begin, end); +} + +// static +gurl_base::StringPiece HttpUtil::TrimLWS(const gurl_base::StringPiece& string) { + const char* begin = string.data(); + const char* end = string.data() + string.size(); + TrimLWSImplementation(&begin, &end); + return gurl_base::StringPiece(begin, end - begin); +} + +// See RFC 7230 Sec 3.2.6 for the definition of |token|. +bool HttpUtil::IsToken(gurl_base::StringPiece string) { + if (string.empty()) + return false; + for (char c : string) { + if (!IsTokenChar(c)) + return false; + } + return true; +} + +bool HttpUtil::IsTokenChar(char c) { + return !(c >= 0x7F || c <= 0x20 || c == '(' || c == ')' || c == '<' || + c == '>' || c == '@' || c == ',' || c == ';' || c == ':' || + c == '\\' || c == '"' || c == '/' || c == '[' || c == ']' || + c == '?' || c == '=' || c == '{' || c == '}'); +} + +} // namespace net diff --git a/net/http/http_util.h b/net/http/http_util.h index 2603ec95..9d75b8de 100644 --- a/net/http/http_util.h +++ b/net/http/http_util.h @@ -1,41 +1,41 @@ -// http_util is derived from chromium/net/http/http_util -#ifndef QUICHE_NET_HTTP_HTTP_UTIL_H_ -#define QUICHE_NET_HTTP_HTTP_UTIL_H_ - -#include -#include "googleurl/base/strings/string_piece.h" - -// This is a macro to support extending this string literal at compile time. -// Please excuse me polluting your global namespace! -#define HTTP_LWS " \t" - -namespace net { - -class HttpUtil { - public: - // Returns true if |name| is a valid HTTP header name. - static bool IsValidHeaderName(gurl_base::StringPiece name); - - // Returns false if |value| contains NUL or CRLF. This method does not perform - // a fully RFC-2616-compliant header value validation. - static bool IsValidHeaderValue(gurl_base::StringPiece value); - - // Return true if the character is HTTP "linear white space" (SP | HT). - // This definition corresponds with the HTTP_LWS macro, and does not match - // newlines. - static bool IsLWS(char c); - - // Trim HTTP_LWS chars from the beginning and end of the string. - static void TrimLWS(std::string::const_iterator* begin, - std::string::const_iterator* end); - static gurl_base::StringPiece TrimLWS(const gurl_base::StringPiece& string); - - // Whether the character is a valid |tchar| as defined in RFC 7230 Sec 3.2.6. - static bool IsTokenChar(char c); - // Whether the string is a valid |token| as defined in RFC 7230 Sec 3.2.6. - static bool IsToken(gurl_base::StringPiece str); -}; - -} // namespace net - -#endif // QUICHE_NET_HTTP_HTTP_UTIL_H_ +// http_util is derived from chromium/net/http/http_util +#ifndef QUICHE_NET_HTTP_HTTP_UTIL_H_ +#define QUICHE_NET_HTTP_HTTP_UTIL_H_ + +#include +#include "googleurl/base/strings/string_piece.h" + +// This is a macro to support extending this string literal at compile time. +// Please excuse me polluting your global namespace! +#define HTTP_LWS " \t" + +namespace net { + +class HttpUtil { + public: + // Returns true if |name| is a valid HTTP header name. + static bool IsValidHeaderName(gurl_base::StringPiece name); + + // Returns false if |value| contains NUL or CRLF. This method does not perform + // a fully RFC-2616-compliant header value validation. + static bool IsValidHeaderValue(gurl_base::StringPiece value); + + // Return true if the character is HTTP "linear white space" (SP | HT). + // This definition corresponds with the HTTP_LWS macro, and does not match + // newlines. + static bool IsLWS(char c); + + // Trim HTTP_LWS chars from the beginning and end of the string. + static void TrimLWS(std::string::const_iterator* begin, + std::string::const_iterator* end); + static gurl_base::StringPiece TrimLWS(const gurl_base::StringPiece& string); + + // Whether the character is a valid |tchar| as defined in RFC 7230 Sec 3.2.6. + static bool IsTokenChar(char c); + // Whether the string is a valid |token| as defined in RFC 7230 Sec 3.2.6. + static bool IsToken(gurl_base::StringPiece str); +}; + +} // namespace net + +#endif // QUICHE_NET_HTTP_HTTP_UTIL_H_ diff --git a/platform/epoll_platform_impl/quic_epoll_clock.cc b/platform/epoll_platform_impl/quic_epoll_clock.cc index 3a14283e..2d3c8490 100644 --- a/platform/epoll_platform_impl/quic_epoll_clock.cc +++ b/platform/epoll_platform_impl/quic_epoll_clock.cc @@ -1,38 +1,38 @@ -// NOLINT(namespace-quic) -// -// This file is part of the QUICHE platform implementation, and is not to be -// consumed or referenced directly by other Envoy code. It serves purely as a -// porting layer for QUICHE. - -#include "platform/epoll_platform_impl/quic_epoll_clock.h" - -namespace quic { - -QuicEpollClock::QuicEpollClock(epoll_server::SimpleEpollServer* epoll_server) - : epoll_server_(epoll_server), largest_time_(QuicTime::Zero()) {} - -QuicTime QuicEpollClock::ApproximateNow() const { - return CreateTimeFromMicroseconds(epoll_server_->ApproximateNowInUsec()); -} - -QuicTime QuicEpollClock::Now() const { - QuicTime now = CreateTimeFromMicroseconds(epoll_server_->NowInUsec()); - - if (now <= largest_time_) { - // Time not increasing, return |largest_time_|. - return largest_time_; - } - - largest_time_ = now; - return largest_time_; -} - -QuicWallTime QuicEpollClock::WallNow() const { - return QuicWallTime::FromUNIXMicroseconds(epoll_server_->ApproximateNowInUsec()); -} - -QuicTime QuicEpollClock::ConvertWallTimeToQuicTime(const QuicWallTime& walltime) const { - return QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(walltime.ToUNIXMicroseconds()); -} - -} // namespace quic +// NOLINT(namespace-quic) +// +// This file is part of the QUICHE platform implementation, and is not to be +// consumed or referenced directly by other Envoy code. It serves purely as a +// porting layer for QUICHE. + +#include "platform/epoll_platform_impl/quic_epoll_clock.h" + +namespace quic { + +QuicEpollClock::QuicEpollClock(epoll_server::SimpleEpollServer* epoll_server) + : epoll_server_(epoll_server), largest_time_(QuicTime::Zero()) {} + +QuicTime QuicEpollClock::ApproximateNow() const { + return CreateTimeFromMicroseconds(epoll_server_->ApproximateNowInUsec()); +} + +QuicTime QuicEpollClock::Now() const { + QuicTime now = CreateTimeFromMicroseconds(epoll_server_->NowInUsec()); + + if (now <= largest_time_) { + // Time not increasing, return |largest_time_|. + return largest_time_; + } + + largest_time_ = now; + return largest_time_; +} + +QuicWallTime QuicEpollClock::WallNow() const { + return QuicWallTime::FromUNIXMicroseconds(epoll_server_->ApproximateNowInUsec()); +} + +QuicTime QuicEpollClock::ConvertWallTimeToQuicTime(const QuicWallTime& walltime) const { + return QuicTime::Zero() + QuicTime::Delta::FromMicroseconds(walltime.ToUNIXMicroseconds()); +} + +} // namespace quic diff --git a/platform/http2_platform_impl/http2_flag_utils_impl.h b/platform/http2_platform_impl/http2_flag_utils_impl.h index f49677f9..5130c5ca 100644 --- a/platform/http2_platform_impl/http2_flag_utils_impl.h +++ b/platform/http2_platform_impl/http2_flag_utils_impl.h @@ -3,7 +3,7 @@ #include "gquiche/common/platform/api/quiche_flag_utils.h" #define HTTP2_RELOADABLE_FLAG_COUNT_IMPL -#define HTTP2_RELOADABLE_FLAG_COUNT_N_IMPL +#define HTTP2_RELOADABLE_FLAG_COUNT_N_IMPL(x,y,z) #define HTTP2_RESTART_FLAG_COUNT_IMPL #define HTTP2_RESTART_FLAG_COUNT_N_IMPL diff --git a/platform/quic_platform_impl/quic_cert_utils_impl.cc b/platform/quic_platform_impl/quic_cert_utils_impl.cc index 3208365c..4468796d 100644 --- a/platform/quic_platform_impl/quic_cert_utils_impl.cc +++ b/platform/quic_platform_impl/quic_cert_utils_impl.cc @@ -1,68 +1,68 @@ -// NOLINT(namespace-quic) - -// This file is part of the QUICHE platform implementation, and is not to be -// consumed or referenced directly by other Envoy code. It serves purely as a -// porting layer for QUICHE. - -#include "platform/quic_platform_impl/quic_cert_utils_impl.h" - -#include "openssl/bytestring.h" - -namespace quic { - -// static -bool QuicCertUtilsImpl::ExtractSubjectNameFromDERCert(quiche::QuicheStringPiece cert, - quiche::QuicheStringPiece* subject_out) { - CBS tbs_certificate; - if (!SeekToSubject(cert, &tbs_certificate)) { - return false; - } - - CBS subject; - if (!CBS_get_asn1_element(&tbs_certificate, &subject, CBS_ASN1_SEQUENCE)) { - return false; - } - *subject_out = - absl::string_view(reinterpret_cast(CBS_data(&subject)), CBS_len(&subject)); - return true; -} - -// static -bool QuicCertUtilsImpl::SeekToSubject(quiche::QuicheStringPiece cert, CBS* tbs_certificate) { - CBS der; - CBS_init(&der, reinterpret_cast(cert.data()), cert.size()); - CBS certificate; - // From RFC 5280, section 4.1 - // Certificate ::= SEQUENCE { - // tbsCertificate TBSCertificate, - // signatureAlgorithm AlgorithmIdentifier, - // signatureValue BIT STRING } - - // TBSCertificate ::= SEQUENCE { - // version [0] EXPLICIT Version DEFAULT v1, - // serialNumber CertificateSerialNumber, - // signature AlgorithmIdentifier, - // issuer Name, - // validity Validity, - // subject Name, - // subjectPublicKeyInfo SubjectPublicKeyInfo, - if (!CBS_get_asn1(&der, &certificate, CBS_ASN1_SEQUENCE) || - CBS_len(&der) != 0 || // We don't allow junk after the certificate. - !CBS_get_asn1(&certificate, tbs_certificate, CBS_ASN1_SEQUENCE) || - // version. - !CBS_get_optional_asn1(tbs_certificate, nullptr, nullptr, - CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 0) || - // Serial number. - !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_INTEGER) || - // Signature. - !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_SEQUENCE) || - // Issuer. - !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_SEQUENCE) || - // Validity. - !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_SEQUENCE)) { - return false; - } - return true; -} - -} // namespace quic +// NOLINT(namespace-quic) + +// This file is part of the QUICHE platform implementation, and is not to be +// consumed or referenced directly by other Envoy code. It serves purely as a +// porting layer for QUICHE. + +#include "platform/quic_platform_impl/quic_cert_utils_impl.h" + +#include "openssl/bytestring.h" + +namespace quic { + +// static +bool QuicCertUtilsImpl::ExtractSubjectNameFromDERCert(quiche::QuicheStringPiece cert, + quiche::QuicheStringPiece* subject_out) { + CBS tbs_certificate; + if (!SeekToSubject(cert, &tbs_certificate)) { + return false; + } + + CBS subject; + if (!CBS_get_asn1_element(&tbs_certificate, &subject, CBS_ASN1_SEQUENCE)) { + return false; + } + *subject_out = + absl::string_view(reinterpret_cast(CBS_data(&subject)), CBS_len(&subject)); + return true; +} + +// static +bool QuicCertUtilsImpl::SeekToSubject(quiche::QuicheStringPiece cert, CBS* tbs_certificate) { + CBS der; + CBS_init(&der, reinterpret_cast(cert.data()), cert.size()); + CBS certificate; + // From RFC 5280, section 4.1 + // Certificate ::= SEQUENCE { + // tbsCertificate TBSCertificate, + // signatureAlgorithm AlgorithmIdentifier, + // signatureValue BIT STRING } + + // TBSCertificate ::= SEQUENCE { + // version [0] EXPLICIT Version DEFAULT v1, + // serialNumber CertificateSerialNumber, + // signature AlgorithmIdentifier, + // issuer Name, + // validity Validity, + // subject Name, + // subjectPublicKeyInfo SubjectPublicKeyInfo, + if (!CBS_get_asn1(&der, &certificate, CBS_ASN1_SEQUENCE) || + CBS_len(&der) != 0 || // We don't allow junk after the certificate. + !CBS_get_asn1(&certificate, tbs_certificate, CBS_ASN1_SEQUENCE) || + // version. + !CBS_get_optional_asn1(tbs_certificate, nullptr, nullptr, + CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 0) || + // Serial number. + !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_INTEGER) || + // Signature. + !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_SEQUENCE) || + // Issuer. + !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_SEQUENCE) || + // Validity. + !CBS_get_asn1(tbs_certificate, nullptr, CBS_ASN1_SEQUENCE)) { + return false; + } + return true; +} + +} // namespace quic diff --git a/platform/quic_platform_impl/quic_file_utils_impl.cc b/platform/quic_platform_impl/quic_file_utils_impl.cc index 32998adf..503d5f37 100644 --- a/platform/quic_platform_impl/quic_file_utils_impl.cc +++ b/platform/quic_platform_impl/quic_file_utils_impl.cc @@ -7,66 +7,11 @@ #include "platform/quic_platform_impl/quic_file_utils_impl.h" #include "absl/strings/str_cat.h" -#include -#include -#include -#include namespace quic { namespace { -static void traverseFilesInDirectory(const std::string& dirname, std::vector& files) { - DIR *dp = opendir(dirname.c_str()); - if (dp == nullptr) { - return; - } - - struct dirent *dirp = nullptr; - while((dirp = readdir(dp)) != nullptr) { - if (strcmp(".", dirp->d_name) == 0 || strcmp("..", dirp->d_name) == 0) { - continue; - } - - struct stat statbuf; - std::string fp = dirname + (dirname[dirname.length() -1 ] == '/' ? "" : "/") + std::string(dirp->d_name); - if(stat(fp.c_str(), &statbuf) == -1) { - continue; - } - if(S_ISDIR(statbuf.st_mode)) { - continue; - } - - files.push_back(std::move(fp)); - } - closedir(dp); -} - void depthFirstTraverseDirectory(const std::string& dirname, std::vector& files) { - DIR *dp = opendir(dirname.c_str()); - if (dp == nullptr) { - return; - } - - struct dirent *dirp = nullptr; - while((dirp = readdir(dp)) != nullptr) { - if (strcmp(".", dirp->d_name) == 0 || strcmp("..", dirp->d_name) == 0) { - continue; - } - - struct stat statbuf; - std::string fp = dirname + (dirname[dirname.length() -1 ] == '/' ? "" : "/") + std::string(dirp->d_name); - if(stat(fp.c_str(), &statbuf) == -1) { - continue; - } - if(S_ISREG(statbuf.st_mode)) { - continue; - } - - traverseFilesInDirectory(fp, files); - } - closedir(dp); - - return; } } // namespace @@ -80,11 +25,6 @@ std::vector ReadFileContentsImpl(const std::string& dirname) { // Reads the contents of |filename| as a string into |contents|. void ReadFileContentsImpl(quiche::QuicheStringPiece filename, std::string* contents) { - std::ifstream ifs(std::string(filename.data(), filename.length())); - ifs.seekg(0, std::ios::end); - contents->reserve(ifs.tellg()); - ifs.seekg(0, std::ios::beg); - contents->assign((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); } } // namespace quic diff --git a/platform/quiche_platform_impl/quiche_logging_impl.h b/platform/quiche_platform_impl/quiche_logging_impl.h index 7e29ec8a..89056aa9 100644 --- a/platform/quiche_platform_impl/quiche_logging_impl.h +++ b/platform/quiche_platform_impl/quiche_logging_impl.h @@ -111,6 +111,7 @@ #define QUICHE_DCHECK_EQ_IMPL(a, b) QUICHE_DCHECK_IMPL((a) == (b)) #define QUICHE_PREDICT_FALSE_IMPL(x) ABSL_PREDICT_FALSE(x) +#define QUICHE_PREDICT_TRUE_IMPL(x) (x) namespace quiche { diff --git a/third_party/abseil-cpp b/third_party/abseil-cpp index bd0de71e..9336be04 160000 --- a/third_party/abseil-cpp +++ b/third_party/abseil-cpp @@ -1 +1 @@ -Subproject commit bd0de71e754eb3280094e89c7ac35a14dac6d61c +Subproject commit 9336be04a242237cd41a525bedfcf3be1bb55377 diff --git a/third_party/boringssl b/third_party/boringssl index 15961379..3a667d10 160000 --- a/third_party/boringssl +++ b/third_party/boringssl @@ -1 +1 @@ -Subproject commit 15961379e6b2682d73c3cb8f8016a09d04257c77 +Subproject commit 3a667d10e94186fd503966f5638e134fe9fb4080 diff --git a/third_party/rapidjson b/third_party/rapidjson new file mode 160000 index 00000000..fcb23c2d --- /dev/null +++ b/third_party/rapidjson @@ -0,0 +1 @@ +Subproject commit fcb23c2dbf561ec0798529be4f66394d3e4996d8 diff --git a/utils/google_quiche_rewrite.sh b/utils/google_quiche_rewrite.sh index a7525f41..3bc0e28c 100644 --- a/utils/google_quiche_rewrite.sh +++ b/utils/google_quiche_rewrite.sh @@ -23,26 +23,26 @@ cat <sed_commands # TODO # Rewrite include directives for gquiche root dir -/^#include/ s!common/!gquiche/common/! -/^#include/ s!epoll_server/!gquiche/epoll_server/! -/^#include/ s!http2/!gquiche/http2/! -/^#include/ s!quic/!gquiche/quic/! -/^#include/ s!spdy/!gquiche/spdy/! +/^#include/ s!"common/!"gquiche/common/! +/^#include/ s!"epoll_server/!"gquiche/epoll_server/! +/^#include/ s!"http2/!"gquiche/http2/! +/^#include/ s!"quic/!"gquiche/quic/! +/^#include/ s!"spdy/!"gquiche/spdy/! # Rewrite include directives for platform impl files. -/^#include/ s!net/quiche/common/platform/impl/!platform/quiche_platform_impl/! -/^#include/ s!quiche_platform_impl/!platform/quiche_platform_impl/! -/^#include/ s!net/tools/epoll_server/platform/impl/!platform/epoll_platform_impl/! -/^#include/ s!net/http2/platform/impl/!platform/http2_platform_impl/! -/^#include/ s!net/quic/platform/impl/!platform/quic_platform_impl/! -/^#include/ s!net/spdy/platform/impl/!platform/spdy_platform_impl/! +/^#include/ s!"net/quiche/common/platform/impl/!"platform/quiche_platform_impl/! +/^#include/ s!"quiche_platform_impl/!"platform/quiche_platform_impl/! +/^#include/ s!"net/tools/epoll_server/platform/impl/!"platform/epoll_platform_impl/! +/^#include/ s!"net/http2/platform/impl/!"platform/http2_platform_impl/! +/^#include/ s!"net/quic/platform/impl/!"platform/quic_platform_impl/! +/^#include/ s!"net/spdy/platform/impl/!"platform/spdy_platform_impl/! # Rewrite gmock & gtest includes. # TODO # Rewrite third_party includes. -/^#include/ s!third_party/boringssl/src/include/!! -/^#include/ s!third_party/zlib/zlib!zlib! +/^#include/ s!"third_party/boringssl/src/include/!"! +/^#include/ s!"third_party/zlib/zlib!"zlib! # Rewrite #pragma clang /^#pragma/ s!clang!GCC!