Skip to content

Commit

Permalink
wip - everything in place, integration tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Oct 7, 2024
1 parent d6d2ce7 commit 3d33120
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 50 deletions.
1 change: 0 additions & 1 deletion src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ namespace Aws
void AddContentBodyToRequest(const std::shared_ptr<Aws::Http::HttpRequest>& httpRequest, const std::shared_ptr<Aws::IOStream>& body,
bool needsContentMd5 = false, bool isChunked = false) const;
void AddCommonHeaders(Aws::Http::HttpRequest& httpRequest) const;
std::shared_ptr<Aws::IOStream> GetBodyStream(const Aws::AmazonWebServiceRequest& request) const;
void AppendHeaderValueToRequest(const std::shared_ptr<Http::HttpRequest> &request, String header, String value) const;

std::shared_ptr<Aws::Http::HttpClient> m_httpClient;
Expand Down
33 changes: 0 additions & 33 deletions src/aws-cpp-sdk-core/include/aws/core/http/ChecksumContext.h

This file was deleted.

14 changes: 11 additions & 3 deletions src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

#include <aws/core/http/URI.h>
#include <aws/core/http/HttpTypes.h>
#include <aws/core/http/ChecksumContext.h>
#include <aws/core/utils/memory/AWSMemory.h>
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/core/utils/stream/ResponseStream.h>
Expand Down Expand Up @@ -93,6 +92,15 @@ namespace Aws
Aws::Map<Aws::String, Aws::String> parameterMap;
};

struct ChecksumContext
{
bool responseChecksumEnabled{};
bool canAwsChunkRequest{};
Aws::String requestChecksumAlgorithmName{};
Aws::Vector<Aws::String> responseChecksums{};
std::shared_ptr<Aws::Http::ServiceSpecificParameters> serviceSpecificParameters{nullptr};
};

/**
* Abstract class for representing an HttpRequest.
*/
Expand Down Expand Up @@ -589,7 +597,7 @@ namespace Aws

inline std::shared_ptr<ServiceSpecificParameters> GetServiceSpecificParameters() { return m_serviceSpecificParameters; }

inline ChecksumContext GetChecksumContext() const { return m_checksumContext;}
ChecksumContext GetChecksumContext() const { return m_checksumContext;}

inline void SetChecksumContext(const ChecksumContext& m_checksum_context) { m_checksumContext = m_checksum_context; }

Expand All @@ -608,7 +616,7 @@ namespace Aws
std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>> m_requestHash;
Aws::Vector<std::pair<Aws::String, std::shared_ptr<Aws::Utils::Crypto::Hash>>> m_responseValidationHashes;
std::shared_ptr<ServiceSpecificParameters> m_serviceSpecificParameters;
ChecksumContext m_checksumContext{false, {}};
ChecksumContext m_checksumContext;
};

} // namespace Http
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once
#include <smithy/interceptor/Interceptor.h>

#include <aws/core/utils/crypto/PrecalculatedHash.h>
#include <aws/core/utils/crypto/CRC32.h>
#include <aws/core/utils/crypto/Sha1.h>
#include <aws/core/utils/crypto/Sha256.h>
Expand All @@ -19,6 +20,7 @@ namespace smithy
class ChecksumInterceptor final : public Interceptor
{
public:

ChecksumInterceptor() = default;
ChecksumInterceptor(const ChecksumInterceptor& other) = delete;
ChecksumInterceptor(ChecksumInterceptor&& other) noexcept = default;
Expand All @@ -29,10 +31,103 @@ namespace smithy
ModifyRequestOutcome ModifyRequest(InterceptorContext& context) override
{
auto httpRequest = context.GetRequest().GetResult();
const auto checksumContext = httpRequest->GetChecksumContext();
const auto serviceSpecificParams = httpRequest->GetChecksumContext().serviceSpecificParameters;
// Request checksums
Aws::String checksumAlgorithmName = Aws::Utils::StringUtils::ToLower(checksumContext.requestChecksumAlgorithmName.c_str());
if (serviceSpecificParams) {
auto requestChecksumOverride = serviceSpecificParams->parameterMap.find("overrideChecksum");
if (requestChecksumOverride != serviceSpecificParams->parameterMap.end()) {
checksumAlgorithmName = requestChecksumOverride->second;
}
}

bool shouldSkipChecksum = serviceSpecificParams &&
serviceSpecificParams->parameterMap.find("overrideChecksumDisable") !=
serviceSpecificParams->parameterMap.end();

//Check if user has provided the checksum algorithm
if (!checksumAlgorithmName.empty() && !shouldSkipChecksum)
{
// Check if user has provided a checksum value for the specified algorithm
const Aws::String checksumType = "x-amz-checksum-" + checksumAlgorithmName;
const Aws::Http::HeaderValueCollection &headers = httpRequest->GetHeaders();
const auto checksumHeader = headers.find(checksumType);
bool checksumValueAndAlgorithmProvided = checksumHeader != headers.end();

// For non-streaming payload, the resolved checksum location is always header.
// For streaming payload, the resolved checksum location depends on whether it is an unsigned payload, we let AwsAuthSigner decide it.
if (checksumContext.canAwsChunkRequest && checksumValueAndAlgorithmProvided)
{
const auto hash = Aws::MakeShared<Aws::Utils::Crypto::PrecalculatedHash>(CHECKSUM_INTERCEPTOR_LOG_TAG, checksumHeader->second);
httpRequest->SetRequestHash(checksumAlgorithmName,hash);
}
else if (checksumValueAndAlgorithmProvided){
httpRequest->SetHeaderValue(checksumType, checksumHeader->second);
}
else if (checksumAlgorithmName == "crc32")
{
if (checksumContext.canAwsChunkRequest)
{
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<Aws::Utils::Crypto::CRC32>(CHECKSUM_INTERCEPTOR_LOG_TAG));
}
else
{
const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared<Aws::StringStream>(CHECKSUM_INTERCEPTOR_LOG_TAG, "");
httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateCRC32(*body)));
}
}
else if (checksumAlgorithmName == "crc32c")
{
if (checksumContext.canAwsChunkRequest)
{
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<Aws::Utils::Crypto::CRC32C>(CHECKSUM_INTERCEPTOR_LOG_TAG));
}
else
{
const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared<Aws::StringStream>(CHECKSUM_INTERCEPTOR_LOG_TAG, "");
httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateCRC32C(*body)));
}
}
else if (checksumAlgorithmName == "sha256")
{
if (checksumContext.canAwsChunkRequest)
{
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<Aws::Utils::Crypto::Sha256>(CHECKSUM_INTERCEPTOR_LOG_TAG));
}
else
{
const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared<Aws::StringStream>(CHECKSUM_INTERCEPTOR_LOG_TAG, "");
httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateSHA256(*body)));
}
}
else if (checksumAlgorithmName == "sha1")
{
if (checksumContext.canAwsChunkRequest)
{
httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared<Aws::Utils::Crypto::Sha1>(CHECKSUM_INTERCEPTOR_LOG_TAG));
}
else
{
const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared<Aws::StringStream>(CHECKSUM_INTERCEPTOR_LOG_TAG, "");
httpRequest->SetHeaderValue(checksumType, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateSHA1(*body)));
}
}
else if (checksumAlgorithmName == "md5" && headers.find(Aws::Http::CONTENT_MD5_HEADER) == headers.end())
{
const auto body = httpRequest->GetContentBody()? httpRequest->GetContentBody() : Aws::MakeShared<Aws::StringStream>(CHECKSUM_INTERCEPTOR_LOG_TAG, "");
httpRequest->SetHeaderValue(Aws::Http::CONTENT_MD5_HEADER, Aws::Utils::HashingUtils::Base64Encode(Aws::Utils::HashingUtils::CalculateMD5(*body)));
}
else if (headers.find(Aws::Http::CONTENT_MD5_HEADER) == headers.end())
{
AWS_LOGSTREAM_WARN(CHECKSUM_INTERCEPTOR_LOG_TAG, "Checksum algorithm: " << checksumAlgorithmName << " is not supported by SDK.");
}
}

// Response checksums
if (httpRequest->GetChecksumContext().ResponseChecksumEnabled())
if (checksumContext.responseChecksumEnabled)
{
for (const Aws::String& responseChecksumAlgorithmName : httpRequest->GetChecksumContext().ResponseChecksums())
for (const Aws::String& responseChecksumAlgorithmName : checksumContext.responseChecksums)
{
const auto lowered = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str());

Expand Down Expand Up @@ -67,7 +162,7 @@ namespace smithy

ModifyResponseOutcome ModifyResponse(InterceptorContext& context) override
{
if (context.GetRequest().GetResult()->GetChecksumContext().ResponseChecksumEnabled())
if (context.GetRequest().GetResult()->GetChecksumContext().responseChecksumEnabled)
{
for (const auto& hashIterator : context.GetRequest().GetResult()->GetResponseValidationHashes())
{
Expand Down
17 changes: 7 additions & 10 deletions src/aws-cpp-sdk-core/source/client/AWSClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,8 +918,13 @@ void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, co
httpRequest->SetDataSentEventHandler(request.GetDataSentEventHandler());
httpRequest->SetContinueRequestHandle(request.GetContinueRequestHandler());
httpRequest->SetServiceSpecificParameters(request.GetServiceSpecificParameters());
httpRequest->SetChecksumContext({request.ShouldValidateResponseChecksum(),
request.GetResponseChecksumAlgorithmNames()});
ChecksumContext context;
context.responseChecksumEnabled = request.ShouldValidateResponseChecksum();
context.canAwsChunkRequest = request.IsStreaming();
context.requestChecksumAlgorithmName = request.GetChecksumAlgorithmName();
context.responseChecksums = request.GetResponseChecksumAlgorithmNames();
context.serviceSpecificParameters = request.GetServiceSpecificParameters();
httpRequest->SetChecksumContext(context);
request.AddQueryStringParameters(httpRequest->GetUri());
}

Expand Down Expand Up @@ -1011,14 +1016,6 @@ Aws::String AWSClient::GeneratePresignedUrl(const Aws::AmazonWebServiceRequest&
return AWSUrlPresigner(*this).GeneratePresignedUrl(request, uri, method, extraParams, expirationInSeconds, serviceSpecificParameter);
}

std::shared_ptr<Aws::IOStream> AWSClient::GetBodyStream(const Aws::AmazonWebServiceRequest& request) const {
if (request.GetBody() != nullptr) {
return request.GetBody();
}
// Return an empty string stream for no body
return Aws::MakeShared<Aws::StringStream>(AWS_CLIENT_LOG_TAG, "");
}

std::shared_ptr<Aws::Http::HttpResponse> AWSClient::MakeHttpRequest(std::shared_ptr<Aws::Http::HttpRequest>& request) const
{
return m_httpClient->MakeRequest(request, m_readRateLimiter.get(), m_writeRateLimiter.get());
Expand Down

0 comments on commit 3d33120

Please sign in to comment.