diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java index 6ab69cb24a92..c7fdbab7ffee 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeInterceptorSpec.java @@ -50,6 +50,7 @@ import software.amazon.awssdk.core.internal.util.MetricUtils; import software.amazon.awssdk.core.metrics.CoreMetric; import software.amazon.awssdk.endpoints.EndpointProvider; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; @@ -148,20 +149,17 @@ private MethodSpec generateAuthSchemeParams() { if (!authSchemeSpecUtils.useEndpointBasedAuthProvider()) { builder.addStatement("$T operation = executionAttributes.getAttribute($T.OPERATION_NAME)", String.class, SdkExecutionAttribute.class); + builder.addStatement("$T.Builder builder = $T.builder().operation(operation)", + authSchemeSpecUtils.parametersInterfaceName(), + authSchemeSpecUtils.parametersInterfaceName()); + if (authSchemeSpecUtils.usesSigV4()) { builder.addStatement("$T region = executionAttributes.getAttribute($T.AWS_REGION)", Region.class, AwsExecutionAttribute.class); - builder.addStatement("return $T.builder()" - + ".operation(operation)" - + ".region(region)" - + ".build()", - authSchemeSpecUtils.parametersInterfaceName()); - } else { - builder.addStatement("return $T.builder()" - + ".operation(operation)" - + ".build()", - authSchemeSpecUtils.parametersInterfaceName()); + builder.addStatement("builder.region(region)"); } + generateSigv4aRegionSet(builder); + builder.addStatement("return builder.build()"); return builder.build(); } @@ -198,6 +196,7 @@ private MethodSpec generateAuthSchemeParams() { builder.addStatement("(($T)builder).endpointProvider(($T)endpointProvider)", paramsBuilderClass, endpointProviderClass); builder.endControlFlow(); builder.endControlFlow(); + // TODO: Implement addRegionSet() for legacy services that resolve authentication from endpoints in one of next PRs. builder.addStatement("return builder.build()"); return builder.build(); } @@ -449,4 +448,23 @@ private TypeName toTypeName(Object valueType) { } return result; } + + private void generateSigv4aRegionSet(MethodSpec.Builder builder) { + if (authSchemeSpecUtils.usesSigV4a()) { + builder.addStatement( + "$T regionSet = executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" + + " .filter(regions -> !regions.isEmpty())\n" + + " .map(regions -> $T.create(String.join(\", \", regions)))\n" + + " .orElseGet(() -> {\n" + + " $T fallbackRegion = executionAttributes.getAttribute($T.AWS_REGION);\n" + + " return fallbackRegion != null ? $T.create(fallbackRegion.toString()) : null;\n" + + " });", + RegionSet.class, AwsExecutionAttribute.class, + RegionSet.class, Region.class, AwsExecutionAttribute.class, + RegionSet.class + ); + + builder.addStatement("builder.regionSet(regionSet)"); + } + } } diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java index c12984e7bdc5..dba6bca98c74 100644 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java +++ b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/auth/scheme/AuthSchemeSpecTest.java @@ -208,6 +208,12 @@ static List parameters() { .classSpecProvider(DefaultAuthSchemeParamsSpec::new) .caseName("ops-auth-sigv4a-value") .outputFileSuffix("default-params") + .build(), + TestCase.builder() + .modelProvider(ClientTestModels::opsWithSigv4a) + .classSpecProvider(AuthSchemeInterceptorSpec::new) + .caseName("ops-auth-sigv4a-value") + .outputFileSuffix("interceptor") .build() ); } diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/ops-auth-sigv4a-value-auth-scheme-interceptor.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/ops-auth-sigv4a-value-auth-scheme-interceptor.java new file mode 100644 index 000000000000..b2117fd2af76 --- /dev/null +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/ops-auth-sigv4a-value-auth-scheme-interceptor.java @@ -0,0 +1,159 @@ +package software.amazon.awssdk.services.database.auth.scheme.internal; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import software.amazon.awssdk.annotations.Generated; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkException; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.internal.util.MetricUtils; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthScheme; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.http.auth.spi.signer.HttpSigner; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.identity.spi.IdentityProvider; +import software.amazon.awssdk.identity.spi.IdentityProviders; +import software.amazon.awssdk.identity.spi.ResolveIdentityRequest; +import software.amazon.awssdk.identity.spi.TokenIdentity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.metrics.SdkMetric; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams; +import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Validate; + +@Generated("software.amazon.awssdk:codegen") +@SdkInternalApi +public final class DatabaseAuthSchemeInterceptor implements ExecutionInterceptor { + private static Logger LOG = Logger.loggerFor(DatabaseAuthSchemeInterceptor.class); + + @Override + public void beforeExecution(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { + List authOptions = resolveAuthOptions(context, executionAttributes); + SelectedAuthScheme selectedAuthScheme = selectAuthScheme(authOptions, executionAttributes); + putSelectedAuthScheme(executionAttributes, selectedAuthScheme); + } + + private List resolveAuthOptions(Context.BeforeExecution context, ExecutionAttributes executionAttributes) { + DatabaseAuthSchemeProvider authSchemeProvider = Validate.isInstanceOf(DatabaseAuthSchemeProvider.class, + executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEME_RESOLVER), + "Expected an instance of DatabaseAuthSchemeProvider"); + DatabaseAuthSchemeParams params = authSchemeParams(context.request(), executionAttributes); + return authSchemeProvider.resolveAuthScheme(params); + } + + private SelectedAuthScheme selectAuthScheme(List authOptions, + ExecutionAttributes executionAttributes) { + MetricCollector metricCollector = executionAttributes.getAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + Map> authSchemes = executionAttributes.getAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES); + IdentityProviders identityProviders = executionAttributes.getAttribute(SdkInternalExecutionAttribute.IDENTITY_PROVIDERS); + List> discardedReasons = new ArrayList<>(); + for (AuthSchemeOption authOption : authOptions) { + AuthScheme authScheme = authSchemes.get(authOption.schemeId()); + SelectedAuthScheme selectedAuthScheme = trySelectAuthScheme(authOption, authScheme, + identityProviders, discardedReasons, metricCollector, executionAttributes); + if (selectedAuthScheme != null) { + if (!discardedReasons.isEmpty()) { + LOG.debug(() -> String.format("%s auth will be used, discarded: '%s'", authOption.schemeId(), + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", ")))); + } + return selectedAuthScheme; + } + } + throw SdkException + .builder() + .message( + "Failed to determine how to authenticate the user: " + + discardedReasons.stream().map(Supplier::get).collect(Collectors.joining(", "))).build(); + } + + private DatabaseAuthSchemeParams authSchemeParams(SdkRequest request, ExecutionAttributes executionAttributes) { + String operation = executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME); + DatabaseAuthSchemeParams.Builder builder = DatabaseAuthSchemeParams.builder().operation(operation); + Region region = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION); + builder.region(region); + RegionSet regionSet = executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET) + .filter(regions -> !regions.isEmpty()).map(regions -> RegionSet.create(String.join(", ", regions))) + .orElseGet(() -> { + Region fallbackRegion = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION); + return fallbackRegion != null ? RegionSet.create(fallbackRegion.toString()) : null; + }); + ; + builder.regionSet(regionSet); + return builder.build(); + } + + private SelectedAuthScheme trySelectAuthScheme(AuthSchemeOption authOption, AuthScheme authScheme, + IdentityProviders identityProviders, List> discardedReasons, MetricCollector metricCollector, + ExecutionAttributes executionAttributes) { + if (authScheme == null) { + discardedReasons.add(() -> String.format("'%s' is not enabled for this request.", authOption.schemeId())); + return null; + } + IdentityProvider identityProvider = authScheme.identityProvider(identityProviders); + if (identityProvider == null) { + discardedReasons + .add(() -> String.format("'%s' does not have an identity provider configured.", authOption.schemeId())); + return null; + } + HttpSigner signer; + try { + signer = authScheme.signer(); + } catch (RuntimeException e) { + discardedReasons.add(() -> String.format("'%s' signer could not be retrieved: %s", authOption.schemeId(), + e.getMessage())); + return null; + } + ResolveIdentityRequest.Builder identityRequestBuilder = ResolveIdentityRequest.builder(); + authOption.forEachIdentityProperty(identityRequestBuilder::putProperty); + CompletableFuture identity; + SdkMetric metric = getIdentityMetric(identityProvider); + if (metric == null) { + identity = identityProvider.resolveIdentity(identityRequestBuilder.build()); + } else { + identity = MetricUtils.reportDuration(() -> identityProvider.resolveIdentity(identityRequestBuilder.build()), + metricCollector, metric); + } + return new SelectedAuthScheme<>(identity, signer, authOption); + } + + private SdkMetric getIdentityMetric(IdentityProvider identityProvider) { + Class identityType = identityProvider.identityType(); + if (identityType == AwsCredentialsIdentity.class) { + return CoreMetric.CREDENTIALS_FETCH_DURATION; + } + if (identityType == TokenIdentity.class) { + return CoreMetric.TOKEN_FETCH_DURATION; + } + return null; + } + + private void putSelectedAuthScheme(ExecutionAttributes attributes, + SelectedAuthScheme selectedAuthScheme) { + SelectedAuthScheme existingAuthScheme = attributes.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (existingAuthScheme != null) { + AuthSchemeOption.Builder selectedOption = selectedAuthScheme.authSchemeOption().toBuilder(); + existingAuthScheme.authSchemeOption().forEachIdentityProperty(selectedOption::putIdentityPropertyIfAbsent); + existingAuthScheme.authSchemeOption().forEachSignerProperty(selectedOption::putSignerPropertyIfAbsent); + selectedAuthScheme = new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), + selectedOption.build()); + } + attributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } +} diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java index 48edb00b1855..0b30a534901e 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/auth/scheme/query-auth-scheme-interceptor.java @@ -84,10 +84,11 @@ private SelectedAuthScheme selectAuthScheme(List SelectedAuthScheme trySelectAuthScheme(AuthSchemeOption authOption, AuthScheme authScheme, IdentityProviders identityProviders, List> discardedReasons, MetricCollector metricCollector, ExecutionAttributes executionAttributes) { diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/c2j/ops-with-auth-sigv4a-value/service-2.json b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/c2j/ops-with-auth-sigv4a-value/service-2.json index abbff04b72b6..313162ffdd09 100644 --- a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/c2j/ops-with-auth-sigv4a-value/service-2.json +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/client/c2j/ops-with-auth-sigv4a-value/service-2.json @@ -6,7 +6,7 @@ "globalEndpoint": "database-service.amazonaws.com", "protocol": "rest-json", "serviceAbbreviation": "Database Service", - "serviceFullName": "Some Service That Uses AWS Database Protocol", + "serviceFullName": "Some Service That Uses AWS Database Protocol With Sigv4a", "serviceId": "Database Service", "signingName": "database-service", "signatureVersion": "v4", diff --git a/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/customization.config b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/customization.config new file mode 100644 index 000000000000..28574274a7ef --- /dev/null +++ b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/customization.config @@ -0,0 +1,4 @@ +{ + "skipEndpointTestGeneration": true, + "useMultiAuth": true +} \ No newline at end of file diff --git a/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/endpoint-rule-set.json b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/endpoint-rule-set.json new file mode 100644 index 000000000000..cf58fb6fe996 --- /dev/null +++ b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/endpoint-rule-set.json @@ -0,0 +1,375 @@ +{ + "version": "1.3", + "parameters": { + "Region": { + "builtIn": "AWS::Region", + "required": true, + "documentation": "The AWS region used to dispatch the request.", + "type": "String" + }, + "UseDualStack": { + "builtIn": "AWS::UseDualStack", + "required": true, + "default": false, + "documentation": "When true, use the dual-stack endpoint. If the configured endpoint does not support dual-stack, dispatching the request MAY return an error.", + "type": "Boolean" + }, + "UseFIPS": { + "builtIn": "AWS::UseFIPS", + "required": true, + "default": false, + "documentation": "When true, send this request to the FIPS-compliant regional endpoint. If the configured endpoint does not have a FIPS compliant endpoint, dispatching the request will return an error.", + "type": "Boolean" + }, + "Endpoint": { + "builtIn": "SDK::Endpoint", + "required": false, + "documentation": "Override the endpoint used to send this request", + "type": "String" + }, + "StaticStringParam": { + "type": "String", + "required": false + }, + "OperationContextParam": { + "type": "String", + "required": false + }, + "RegionWithDefault": { + "type": "String", + "required": true, + "default": "us-east-1", + "builtIn": "AWS::Region" + }, + "BooleanClientContextParam": { + "type": "Boolean" + }, + "StringClientContextParam": { + "type": "String" + } + }, + "rules": [ + { + "conditions": [ + { + "fn": "aws.partition", + "argv": [ + { + "ref": "Region" + } + ], + "assign": "PartitionResult" + } + ], + "type": "tree", + "rules": [ + { + "conditions": [ + { + "fn": "isSet", + "argv": [ + { + "ref": "Endpoint" + } + ] + }, + { + "fn": "parseURL", + "argv": [ + { + "ref": "Endpoint" + } + ], + "assign": "url" + } + ], + "type": "tree", + "rules": [ + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFIPS" + }, + true + ] + } + ], + "error": "Invalid Configuration: FIPS and custom endpoint are not supported", + "type": "error" + }, + { + "conditions": [], + "type": "tree", + "rules": [ + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseDualStack" + }, + true + ] + } + ], + "error": "Invalid Configuration: Dualstack and custom endpoint are not supported", + "type": "error" + }, + { + "conditions": [], + "endpoint": { + "url": { + "ref": "Endpoint" + }, + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingRegion": "{Region}", + "signingName": "restjson" + } + ] + }, + "headers": {} + }, + "type": "endpoint" + } + ] + } + ] + }, + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFIPS" + }, + true + ] + }, + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseDualStack" + }, + true + ] + } + ], + "type": "tree", + "rules": [ + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + true, + { + "fn": "getAttr", + "argv": [ + { + "ref": "PartitionResult" + }, + "supportsFIPS" + ] + } + ] + }, + { + "fn": "booleanEquals", + "argv": [ + true, + { + "fn": "getAttr", + "argv": [ + { + "ref": "PartitionResult" + }, + "supportsDualStack" + ] + } + ] + } + ], + "type": "tree", + "rules": [ + { + "conditions": [], + "endpoint": { + "url": "https://restjson-fips.{Region}.{PartitionResult#dualStackDnsSuffix}", + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingRegion": "{Region}", + "signingName": "restjson" + } + ] + }, + "headers": {} + }, + "type": "endpoint" + } + ] + }, + { + "conditions": [], + "error": "FIPS and DualStack are enabled, but this partition does not support one or both", + "type": "error" + } + ] + }, + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseFIPS" + }, + true + ] + } + ], + "type": "tree", + "rules": [ + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + true, + { + "fn": "getAttr", + "argv": [ + { + "ref": "PartitionResult" + }, + "supportsFIPS" + ] + } + ] + } + ], + "type": "tree", + "rules": [ + { + "conditions": [], + "type": "tree", + "rules": [ + { + "conditions": [], + "endpoint": { + "url": "https://restjson-fips.{Region}.{PartitionResult#dnsSuffix}", + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingRegion": "{Region}", + "signingName": "restjson" + } + ] + }, + "headers": {} + }, + "type": "endpoint" + } + ] + } + ] + }, + { + "conditions": [], + "error": "FIPS is enabled but this partition does not support FIPS", + "type": "error" + } + ] + }, + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + { + "ref": "UseDualStack" + }, + true + ] + } + ], + "type": "tree", + "rules": [ + { + "conditions": [ + { + "fn": "booleanEquals", + "argv": [ + true, + { + "fn": "getAttr", + "argv": [ + { + "ref": "PartitionResult" + }, + "supportsDualStack" + ] + } + ] + } + ], + "type": "tree", + "rules": [ + { + "conditions": [], + "endpoint": { + "url": "https://restjson.{Region}.{PartitionResult#dualStackDnsSuffix}", + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingRegion": "{Region}", + "signingName": "restjson" + } + ] + }, + "headers": {} + }, + "type": "endpoint" + } + ] + }, + { + "conditions": [], + "error": "DualStack is enabled but this partition does not support DualStack", + "type": "error" + } + ] + }, + { + "conditions": [], + "endpoint": { + "url": "https://restjson.{Region}.{PartitionResult#dnsSuffix}", + "properties": { + "authSchemes": [ + { + "name": "sigv4", + "signingRegion": "{Region}", + "signingName": "restjson" + } + ] + }, + "headers": {} + }, + "type": "endpoint" + } + ] + } + ] +} \ No newline at end of file diff --git a/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/endpoint-tests.json b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/endpoint-tests.json new file mode 100644 index 000000000000..f94902ff9d99 --- /dev/null +++ b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/endpoint-tests.json @@ -0,0 +1,5 @@ +{ + "testCases": [ + ], + "version": "1.0" +} \ No newline at end of file diff --git a/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/service-2.json b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/service-2.json new file mode 100644 index 000000000000..e0be3bee5ca7 --- /dev/null +++ b/test/codegen-generated-classes-test/src/main/resources/codegen-resources/multiauth/service-2.json @@ -0,0 +1,47 @@ +{ + "version":"2.0", + "metadata":{ + "apiVersion":"2016-03-11", + "endpointPrefix":"internalconfig", + "jsonVersion":"1.1", + "protocol":"rest-json", + "serviceAbbreviation":"AwsMultiAuthService", + "serviceFullName":"AWS Multi Auth Service", + "serviceId":"Multiauth", + "signatureVersion":"v4", + "targetPrefix":"MultiAuth", + "timestampFormat":"unixTimestamp", + "uid":"restjson-2016-03-11" + }, + "operations":{ + "sigv4aOperation":{ + "name":"sigv4a", + "http":{ + "method":"POST", + "requestUri":"/2016-03-11/sigv4aoperation" + }, + "input":{"shape":"sigv4aShape"}, + "auth": ["aws.auth#sigv4a"] + }, + "sigv4AndSigv4aOperation":{ + "name":"sigv4a", + "http":{ + "method":"POST", + "requestUri":"/2016-03-11/sigv4andsigv4aoperation" + }, + "input":{"shape":"sigv4aShape"}, + "auth": ["aws.auth#sigv4a", "aws.auth#sigv4"] + } + }, + "shapes": { + "sigv4aShape": { + "type": "structure", + "members": { + "StringMember": { + "shape": "String" + } + } + }, + "String":{"type":"string"} + } +} diff --git a/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/multiauth/Sigv4aMultiAuthTest.java b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/multiauth/Sigv4aMultiAuthTest.java new file mode 100644 index 000000000000..8aabc1dd39bf --- /dev/null +++ b/test/codegen-generated-classes-test/src/test/java/software/amazon/awssdk/services/multiauth/Sigv4aMultiAuthTest.java @@ -0,0 +1,143 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.multiauth; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.core.SdkSystemSetting; +import software.amazon.awssdk.http.HttpExecuteRequest; +import software.amazon.awssdk.http.SdkHttpClient; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4AuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.multiauth.auth.scheme.MultiauthAuthSchemeParams; +import software.amazon.awssdk.services.multiauth.auth.scheme.MultiauthAuthSchemeProvider; +import software.amazon.awssdk.testutils.EnvironmentVariableHelper; + +/** + * Unit tests for the Sigv4a multi-auth functionality. + */ +class Sigv4aMultiAuthTest { + + private EnvironmentVariableHelper environmentVariableHelper; + private SdkHttpClient mockHttpClient; + private MultiauthAuthSchemeProvider multiauthAuthSchemeProvider; + + @BeforeEach + void setUp() { + environmentVariableHelper = new EnvironmentVariableHelper(); + multiauthAuthSchemeProvider = mock(MultiauthAuthSchemeProvider.class); + + mockHttpClient = mock(SdkHttpClient.class); + when(mockHttpClient.clientName()).thenReturn("MockHttpClient"); + when(mockHttpClient.prepareRequest(any())).thenThrow(new RuntimeException("expected exception")); + + List authSchemeOptions = Collections.singletonList( + AuthSchemeOption.builder().schemeId(AwsV4AuthScheme.SCHEME_ID).build() + ); + when(multiauthAuthSchemeProvider.resolveAuthScheme(any(MultiauthAuthSchemeParams.class))) + .thenReturn(authSchemeOptions); + } + + @AfterEach + void tearDown() { + environmentVariableHelper.reset(); + } + + @Test + void requestHasRegionSetParamsUpdatedToRegion() { + environmentVariableHelper.set(SdkSystemSetting.AWS_SIGV4A_SIGNING_REGION_SET, "us-west-2,us-west-1"); + + MultiauthClient multiauthClient = MultiauthClient.builder() + .httpClient(mockHttpClient) + .authSchemeProvider(multiauthAuthSchemeProvider) + .region(Region.US_WEST_2) + .build(); + + assertThatThrownBy(() -> multiauthClient.sigv4aOperation(r -> r.stringMember(""))) + .hasMessageContaining("expected exception"); + + ArgumentCaptor paramsCaptor = + ArgumentCaptor.forClass(MultiauthAuthSchemeParams.class); + verify(multiauthAuthSchemeProvider).resolveAuthScheme(paramsCaptor.capture()); + + MultiauthAuthSchemeParams resolvedAuthSchemeParams = paramsCaptor.getValue(); + assertThat(resolvedAuthSchemeParams.regionSet()) + .isEqualTo(RegionSet.create(Arrays.asList("us-west-2", "us-west-1"))); + } + + @Test + void requestHasRegionSetSdkSystemSettings() { + MultiauthClient multiauthClient = MultiauthClient.builder() + .httpClient(mockHttpClient) + .authSchemeProvider(multiauthAuthSchemeProvider) + .region(Region.US_WEST_2) + .build(); + + assertThatThrownBy(() -> multiauthClient.sigv4aOperation(r -> r.stringMember(""))) + .hasMessageContaining("expected exception"); + + ArgumentCaptor paramsCaptor = + ArgumentCaptor.forClass(MultiauthAuthSchemeParams.class); + verify(multiauthAuthSchemeProvider).resolveAuthScheme(paramsCaptor.capture()); + + MultiauthAuthSchemeParams resolvedAuthSchemeParams = paramsCaptor.getValue(); + assertThat(resolvedAuthSchemeParams.regionSet()) + .isEqualTo(RegionSet.create(Region.US_WEST_2.toString())); + } + + @Test + void errorWhenSigv4aDoesNotHasFallbackSigv4() { + MultiauthClient multiauthClient = MultiauthClient.builder() + .httpClient(mockHttpClient) + .region(Region.US_WEST_2) + .build(); + + assertThatThrownBy(() -> multiauthClient.sigv4aOperation(r -> r.stringMember(""))) + .hasMessageContaining("You must add a dependency on the 'software.amazon.awssdk:http-auth-aws-crt' " + + "module to enable the CRT-V4a signing feature"); + } + + @Test + void fallBackToSigv4WhenSigv4aIsNotAvailable() { + MultiauthClient multiauthClient = MultiauthClient.builder() + .httpClient(mockHttpClient) + .region(Region.US_WEST_2) + .build(); + + assertThatThrownBy(() -> multiauthClient.sigv4AndSigv4aOperation(r -> r.stringMember(""))) + .hasMessageContaining("expected exception"); + + ArgumentCaptor httpRequestCaptor = ArgumentCaptor.forClass(HttpExecuteRequest.class); + verify(mockHttpClient).prepareRequest(httpRequestCaptor.capture()); + SdkHttpRequest request = httpRequestCaptor.getAllValues().get(0).httpRequest(); + assertThat(request.firstMatchingHeader("Authorization")).isPresent(); + } +}