Skip to content

Commit

Permalink
Use capturing-interceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
haydenbaker committed Nov 1, 2023
1 parent 2829700 commit 5cf2ed1
Showing 1 changed file with 59 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,42 +15,35 @@

package software.amazon.awssdk.services;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static software.amazon.awssdk.core.client.config.SdkAdvancedClientOption.SIGNER;

import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute;
import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme;
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.interceptor.Context;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
import software.amazon.awssdk.core.signer.Signer;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.endpoints.Endpoint;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4FamilyHttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner;
import software.amazon.awssdk.http.auth.aws.signer.RegionSet;
import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.protocolquery.ProtocolQueryAsyncClient;
import software.amazon.awssdk.services.protocolquery.ProtocolQueryClient;
import software.amazon.awssdk.services.protocolquery.auth.scheme.ProtocolQueryAuthSchemeProvider;
Expand All @@ -62,11 +55,11 @@
*/
public class SignerAndEndpointOverridesTest {

private static Signer mockSigner = mock(Signer.class);
private CapturingInterceptor interceptor;

@BeforeAll
static void setup() {
when(mockSigner.sign(any(), any())).thenThrow(new RuntimeException("boom!"));
@BeforeEach
void setup() {
this.interceptor = new CapturingInterceptor();
}

@Test
Expand All @@ -76,22 +69,13 @@ void test_whenV4EndpointAuthSchemeWithSignerOverride_thenEndpointParamsShouldPro
.authSchemeProvider(v4AuthSchemeProviderOverride())
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
.endpointProvider(v4EndpointProviderOverride())
.region(Region.US_WEST_1)
.overrideConfiguration(o -> o.putAdvancedOption(SIGNER, mockSigner))
.overrideConfiguration(o -> o.putAdvancedOption(SIGNER, Aws4Signer.create()).addExecutionInterceptor(interceptor))
.build();

Exception ex = assertThrows(
RuntimeException.class, () -> client.streamingInputOperation(r -> {}, RequestBody.fromString("test"))
);

assertThat(ex.getMessage()).contains("boom!");
verify(mockSigner).sign(
any(SdkHttpFullRequest.class),
argThat(attrs ->
"us-west-1".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id()) &&
"query-test-v4".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME)) &&
Boolean.FALSE.equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE)))
);
assertThatThrownBy(() -> client.allTypes(r -> {})).hasMessageContaining("boom!");
assertEquals("us-west-1", interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id());
assertEquals("query-test-v4", interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME));
assertEquals(Boolean.FALSE, interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE));
}

@Test
Expand All @@ -101,22 +85,13 @@ void testAsync_whenV4EndpointAuthSchemeWithSignerOverride_thenEndpointParamsShou
.authSchemeProvider(v4AuthSchemeProviderOverride())
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
.endpointProvider(v4EndpointProviderOverride())
.region(Region.US_WEST_1)
.overrideConfiguration(o -> o.putAdvancedOption(SIGNER, mockSigner))
.overrideConfiguration(o -> o.putAdvancedOption(SIGNER, Aws4Signer.create()).addExecutionInterceptor(interceptor))
.build();

Exception ex = assertThrows(
RuntimeException.class, () -> client.streamingInputOperation(r -> {}, AsyncRequestBody.fromString("test")).join()
);

assertThat(ex.getMessage()).contains("boom!");
verify(mockSigner).sign(
any(SdkHttpFullRequest.class),
argThat(attrs ->
"us-west-1".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id()) &&
"query-test-v4".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME)) &&
Boolean.FALSE.equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE)))
);
assertThatThrownBy(() -> client.allTypes(r -> {}).join()).hasMessageContaining("boom!");
assertEquals("us-west-1", interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id());
assertEquals("query-test-v4", interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME));
assertEquals(Boolean.FALSE, interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE));
}

@Test
Expand All @@ -126,26 +101,19 @@ void test_whenV4aEndpointAuthSchemeWithSignerOverride_thenEndpointParamsShouldPr
.authSchemeProvider(v4aAuthSchemeProviderOverride())
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
.endpointProvider(v4aEndpointProviderOverride())
.region(Region.US_EAST_1)
.overrideConfiguration(
o -> o.putAdvancedOption(SIGNER, mockSigner)
o -> o.putAdvancedOption(SIGNER, Aws4Signer.create())
.addExecutionInterceptor(interceptor)
.putExecutionAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, Collections.singletonMap(
"aws.auth#sigv4a", AwsV4aAuthScheme.create()
)))
.build();

Exception ex = assertThrows(
RuntimeException.class, () -> client.streamingInputOperation(r -> {}, RequestBody.fromString("test"))
);

assertThat(ex.getMessage()).contains("boom!");
verify(mockSigner).sign(
any(SdkHttpFullRequest.class),
argThat(attrs ->
"us-east-1".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id()) &&
"query-test-v4a".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME)) &&
Boolean.FALSE.equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE)))
);
assertThatThrownBy(() -> client.allTypes(r -> {})).hasMessageContaining("boom!");
assertEquals("us-east-1", interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION_SCOPE).id());
assertEquals("query-test-v4a",
interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME));
assertEquals(Boolean.FALSE, interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE));
}

@Test
Expand All @@ -155,30 +123,23 @@ void testAsync_whenV4aEndpointAuthSchemeWithSignerOverride_thenEndpointParamsSho
.authSchemeProvider(v4aAuthSchemeProviderOverride())
.credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")))
.endpointProvider(v4aEndpointProviderOverride())
.region(Region.US_EAST_1)
.overrideConfiguration(
o -> o.putAdvancedOption(SIGNER, mockSigner)
o -> o.putAdvancedOption(SIGNER, Aws4Signer.create())
.addExecutionInterceptor(interceptor)
.putExecutionAttribute(SdkInternalExecutionAttribute.AUTH_SCHEMES, Collections.singletonMap(
"aws.auth#sigv4a", AwsV4aAuthScheme.create()
)))
.build();

Exception ex = assertThrows(
SdkClientException.class, () -> client.streamingInputOperation(r -> {}, AsyncRequestBody.fromString("test")).join()
);

assertThat(ex.getMessage()).contains("boom!");
verify(mockSigner).sign(
any(SdkHttpFullRequest.class),
argThat(attrs ->
"us-east-1".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION).id()) &&
"query-test-v4a".equals(attrs.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME)) &&
Boolean.FALSE.equals(attrs.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE)))
);
assertThatThrownBy(() -> client.allTypes(r -> {}).join()).hasMessageContaining("boom!");
assertEquals("us-east-1", interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNING_REGION_SCOPE).id());
assertEquals("query-test-v4a",
interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME));
assertEquals(Boolean.FALSE, interceptor.executionAttributes.getAttribute(AwsSignerExecutionAttribute.SIGNER_DOUBLE_URL_ENCODE));
}

private static ProtocolQueryAuthSchemeProvider v4AuthSchemeProviderOverride() {
return __ -> {
return x -> {
List<AuthSchemeOption> options = new ArrayList<>();
options.add(
AuthSchemeOption.builder().schemeId("aws.auth#sigv4")
Expand All @@ -191,7 +152,7 @@ private static ProtocolQueryAuthSchemeProvider v4AuthSchemeProviderOverride() {
}

private static ProtocolQueryAuthSchemeProvider v4aAuthSchemeProviderOverride() {
return __ -> {
return x -> {
List<AuthSchemeOption> options = new ArrayList<>();
options.add(
AuthSchemeOption.builder().schemeId("aws.auth#sigv4a")
Expand Down Expand Up @@ -238,4 +199,26 @@ private static ProtocolQueryEndpointProvider v4aEndpointProviderOverride() {
return CompletableFuture.completedFuture(endpoint);
};
}

public static class CapturingInterceptor implements ExecutionInterceptor {
private Context.BeforeTransmission context;
private ExecutionAttributes executionAttributes;

@Override
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
this.context = context;
this.executionAttributes = executionAttributes;
throw new RuntimeException("boom!");
}

public ExecutionAttributes executionAttributes() {
return executionAttributes;
}

public class CaptureCompletedException extends RuntimeException {
CaptureCompletedException(String message) {
super(message);
}
}
}
}

0 comments on commit 5cf2ed1

Please sign in to comment.