From 3d3312020816041d783cb39caeabf8acbc7ed44b Mon Sep 17 00:00:00 2001 From: sbiscigl Date: Mon, 7 Oct 2024 16:17:06 -0400 Subject: [PATCH] wip - everything in place, integration tests passing --- .../include/aws/core/client/AWSClient.h | 1 - .../include/aws/core/http/ChecksumContext.h | 33 ------ .../include/aws/core/http/HttpRequest.h | 14 ++- .../interceptor/impl/ChecksumInterceptor.h | 101 +++++++++++++++++- .../source/client/AWSClient.cpp | 17 ++- 5 files changed, 116 insertions(+), 50 deletions(-) delete mode 100644 src/aws-cpp-sdk-core/include/aws/core/http/ChecksumContext.h diff --git a/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h b/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h index 9d050df6a6b..749ac31c8ae 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h @@ -342,7 +342,6 @@ namespace Aws void AddContentBodyToRequest(const std::shared_ptr& httpRequest, const std::shared_ptr& body, bool needsContentMd5 = false, bool isChunked = false) const; void AddCommonHeaders(Aws::Http::HttpRequest& httpRequest) const; - std::shared_ptr GetBodyStream(const Aws::AmazonWebServiceRequest& request) const; void AppendHeaderValueToRequest(const std::shared_ptr &request, String header, String value) const; std::shared_ptr m_httpClient; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/ChecksumContext.h b/src/aws-cpp-sdk-core/include/aws/core/http/ChecksumContext.h deleted file mode 100644 index ce686c1970b..00000000000 --- a/src/aws-cpp-sdk-core/include/aws/core/http/ChecksumContext.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once -#include "aws/core/utils/memory/stl/AWSString.h" -#include "aws/core/utils/memory/stl/AWSVector.h" - -namespace Aws -{ - namespace Http - { - class ChecksumContext final - { - public: - ChecksumContext(bool responseChecksumEnabled, const Aws::Vector& responseChecksums) - : m_responseChecksumEnabled(responseChecksumEnabled), - m_responseChecksums(responseChecksums) - { - } - - bool ResponseChecksumEnabled() const - { - return m_responseChecksumEnabled; - } - - Aws::Vector ResponseChecksums() const - { - return m_responseChecksums; - } - - private: - bool m_responseChecksumEnabled{}; - Aws::Vector m_responseChecksums{}; - }; - } -} diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h index dc05bcd78ed..5bb3b54c5ec 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -93,6 +92,15 @@ namespace Aws Aws::Map parameterMap; }; + struct ChecksumContext + { + bool responseChecksumEnabled{}; + bool canAwsChunkRequest{}; + Aws::String requestChecksumAlgorithmName{}; + Aws::Vector responseChecksums{}; + std::shared_ptr serviceSpecificParameters{nullptr}; + }; + /** * Abstract class for representing an HttpRequest. */ @@ -589,7 +597,7 @@ namespace Aws inline std::shared_ptr GetServiceSpecificParameters() { return m_serviceSpecificParameters; } - inline ChecksumContext GetChecksumContext() const { return m_checksumContext;} + ChecksumContext GetChecksumContext() const { return m_checksumContext;} inline void SetChecksumContext(const ChecksumContext& m_checksum_context) { m_checksumContext = m_checksum_context; } @@ -608,7 +616,7 @@ namespace Aws std::pair> m_requestHash; Aws::Vector>> m_responseValidationHashes; std::shared_ptr m_serviceSpecificParameters; - ChecksumContext m_checksumContext{false, {}}; + ChecksumContext m_checksumContext; }; } // namespace Http diff --git a/src/aws-cpp-sdk-core/include/smithy/interceptor/impl/ChecksumInterceptor.h b/src/aws-cpp-sdk-core/include/smithy/interceptor/impl/ChecksumInterceptor.h index 4e0aa896f84..c8adc9c7e68 100644 --- a/src/aws-cpp-sdk-core/include/smithy/interceptor/impl/ChecksumInterceptor.h +++ b/src/aws-cpp-sdk-core/include/smithy/interceptor/impl/ChecksumInterceptor.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include #include @@ -19,6 +20,7 @@ namespace smithy class ChecksumInterceptor final : public Interceptor { public: + ChecksumInterceptor() = default; ChecksumInterceptor(const ChecksumInterceptor& other) = delete; ChecksumInterceptor(ChecksumInterceptor&& other) noexcept = default; @@ -29,10 +31,103 @@ namespace smithy ModifyRequestOutcome ModifyRequest(InterceptorContext& context) override { auto httpRequest = context.GetRequest().GetResult(); + const auto checksumContext = httpRequest->GetChecksumContext(); + const auto serviceSpecificParams = httpRequest->GetChecksumContext().serviceSpecificParameters; + // Request checksums + Aws::String checksumAlgorithmName = Aws::Utils::StringUtils::ToLower(checksumContext.requestChecksumAlgorithmName.c_str()); + if (serviceSpecificParams) { + auto requestChecksumOverride = serviceSpecificParams->parameterMap.find("overrideChecksum"); + if (requestChecksumOverride != serviceSpecificParams->parameterMap.end()) { + checksumAlgorithmName = requestChecksumOverride->second; + } + } + + bool shouldSkipChecksum = serviceSpecificParams && + serviceSpecificParams->parameterMap.find("overrideChecksumDisable") != + serviceSpecificParams->parameterMap.end(); + + //Check if user has provided the checksum algorithm + if (!checksumAlgorithmName.empty() && !shouldSkipChecksum) + { + // Check if user has provided a checksum value for the specified algorithm + const Aws::String checksumType = "x-amz-checksum-" + checksumAlgorithmName; + const Aws::Http::HeaderValueCollection &headers = httpRequest->GetHeaders(); + const auto checksumHeader = headers.find(checksumType); + bool checksumValueAndAlgorithmProvided = checksumHeader != headers.end(); + + // For non-streaming payload, the resolved checksum location is always header. + // For streaming payload, the resolved checksum location depends on whether it is an unsigned payload, we let AwsAuthSigner decide it. + if (checksumContext.canAwsChunkRequest && checksumValueAndAlgorithmProvided) + { + const auto hash = Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG, checksumHeader->second); + httpRequest->SetRequestHash(checksumAlgorithmName,hash); + } + else if (checksumValueAndAlgorithmProvided){ + httpRequest->SetHeaderValue(checksumType, checksumHeader->second); + } + else if (checksumAlgorithmName == "crc32") + { + if (checksumContext.canAwsChunkRequest) + { + httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG)); + } + else + { + const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG, ""); + httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateCRC32(*body))); + } + } + else if (checksumAlgorithmName == "crc32c") + { + if (checksumContext.canAwsChunkRequest) + { + httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG)); + } + else + { + const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG, ""); + httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateCRC32C(*body))); + } + } + else if (checksumAlgorithmName == "sha256") + { + if (checksumContext.canAwsChunkRequest) + { + httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG)); + } + else + { + const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG, ""); + httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateSHA256(*body))); + } + } + else if (checksumAlgorithmName == "sha1") + { + if (checksumContext.canAwsChunkRequest) + { + httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG)); + } + else + { + const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG, ""); + httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateSHA1(*body))); + } + } + else if (checksumAlgorithmName == "md5" && headers.find(Aws::Http::CONTENT_MD5_HEADER) == headers.end()) + { + const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared(CHECKSUM_INTERCEPTOR_LOG_TAG, ""); + httpRequest->SetHeaderValue(Aws::Http::CONTENT_MD5_HEADER, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(*body))); + } + else if (headers.find(Aws::Http::CONTENT_MD5_HEADER) == headers.end()) + { + AWS_LOGSTREAM_WARN(CHECKSUM_INTERCEPTOR_LOG_TAG, "Checksum algorithm: " << checksumAlgorithmName << " is not supported by SDK."); + } + } + // Response checksums - if (httpRequest->GetChecksumContext().ResponseChecksumEnabled()) + if (checksumContext.responseChecksumEnabled) { - for (const Aws::String& responseChecksumAlgorithmName : httpRequest->GetChecksumContext().ResponseChecksums()) + for (const Aws::String& responseChecksumAlgorithmName : checksumContext.responseChecksums) { const auto lowered = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str()); @@ -67,7 +162,7 @@ namespace smithy ModifyResponseOutcome ModifyResponse(InterceptorContext& context) override { - if (context.GetRequest().GetResult()->GetChecksumContext().ResponseChecksumEnabled()) + if (context.GetRequest().GetResult()->GetChecksumContext().responseChecksumEnabled) { for (const auto& hashIterator : context.GetRequest().GetResult()->GetResponseValidationHashes()) { diff --git a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp index 0b7bf4dd0a4..5b17661d43e 100644 --- a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp +++ b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp @@ -918,8 +918,13 @@ void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, co httpRequest->SetDataSentEventHandler(request.GetDataSentEventHandler()); httpRequest->SetContinueRequestHandle(request.GetContinueRequestHandler()); httpRequest->SetServiceSpecificParameters(request.GetServiceSpecificParameters()); - httpRequest->SetChecksumContext({request.ShouldValidateResponseChecksum(), - request.GetResponseChecksumAlgorithmNames()}); + ChecksumContext context; + context.responseChecksumEnabled = request.ShouldValidateResponseChecksum(); + context.canAwsChunkRequest = request.IsStreaming(); + context.requestChecksumAlgorithmName = request.GetChecksumAlgorithmName(); + context.responseChecksums = request.GetResponseChecksumAlgorithmNames(); + context.serviceSpecificParameters = request.GetServiceSpecificParameters(); + httpRequest->SetChecksumContext(context); request.AddQueryStringParameters(httpRequest->GetUri()); } @@ -1011,14 +1016,6 @@ Aws::String AWSClient::GeneratePresignedUrl(const Aws::AmazonWebServiceRequest& return AWSUrlPresigner(*this).GeneratePresignedUrl(request, uri, method, extraParams, expirationInSeconds, serviceSpecificParameter); } -std::shared_ptr AWSClient::GetBodyStream(const Aws::AmazonWebServiceRequest& request) const { - if (request.GetBody() != nullptr) { - return request.GetBody(); - } - // Return an empty string stream for no body - return Aws::MakeShared(AWS_CLIENT_LOG_TAG, ""); -} - std::shared_ptr AWSClient::MakeHttpRequest(std::shared_ptr& request) const { return m_httpClient->MakeRequest(request, m_readRateLimiter.get(), m_writeRateLimiter.get());