Skip to content

Commit 307d75e

Browse files
committed
Moving AuthStrippingInterceptor to registry-aws and adding region as input parameter to ContainerRegistryClient
1 parent 65f76e9 commit 307d75e

File tree

7 files changed

+157
-120
lines changed

7 files changed

+157
-120
lines changed

registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.salesforce.multicloudj.registry.aws;
22

33
import com.google.auto.service.AutoService;
4+
import com.salesforce.multicloudj.common.aws.AwsConstants;
45
import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping;
56
import com.salesforce.multicloudj.common.aws.CredentialsProvider;
67
import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
@@ -9,6 +10,7 @@
910
import com.salesforce.multicloudj.registry.driver.AbstractRegistry;
1011
import com.salesforce.multicloudj.registry.driver.OciRegistryClient;
1112
import org.apache.commons.lang3.StringUtils;
13+
import org.apache.http.HttpRequestInterceptor;
1214
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
1315
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
1416
import software.amazon.awssdk.awscore.exception.AwsServiceException;
@@ -19,6 +21,7 @@
1921

2022
import java.nio.charset.StandardCharsets;
2123
import java.util.Base64;
24+
import java.util.List;
2225

2326
/**
2427
* AWS Elastic Container Registry (ECR) implementation.
@@ -32,7 +35,6 @@
3235
@AutoService(AbstractRegistry.class)
3336
public class AwsRegistry extends AbstractRegistry {
3437

35-
public static final String PROVIDER_ID = "aws";
3638
private static final String AWS_AUTH_USERNAME = "AWS";
3739

3840
/** Lock for thread-safe lazy initialization of ECR client. */
@@ -41,7 +43,6 @@ public class AwsRegistry extends AbstractRegistry {
4143
/** Lock for thread-safe token refresh. */
4244
private final Object tokenLock = new Object();
4345

44-
private final String region;
4546
private final OciRegistryClient ociClient;
4647

4748
/** Lazily initialized ECR client with double-checked locking. */
@@ -67,7 +68,6 @@ public AwsRegistry(Builder builder) {
6768
*/
6869
public AwsRegistry(Builder builder, EcrClient ecrClient) {
6970
super(builder);
70-
this.region = builder.getRegion();
7171
this.ecrClient = ecrClient;
7272
this.ociClient = registryEndpoint != null ? new OciRegistryClient(registryEndpoint, this) : null;
7373
}
@@ -99,6 +99,11 @@ public String getAuthToken() {
9999
return cachedAuthToken;
100100
}
101101

102+
@Override
103+
protected List<HttpRequestInterceptor> getInterceptors() {
104+
return List.of(new AuthStrippingInterceptor(registryEndpoint));
105+
}
106+
102107
/**
103108
* Returns true if the current time is past the halfway point of the token's validity window,
104109
* at which point a proactive refresh is triggered i.e. treats tokens as invalid after 50% of their lifetime.
@@ -198,18 +203,7 @@ public void close() throws Exception {
198203
public static final class Builder extends AbstractRegistry.Builder<AwsRegistry, Builder> {
199204

200205
public Builder() {
201-
providerId(PROVIDER_ID);
202-
}
203-
204-
private String region;
205-
206-
public Builder withRegion(String region) {
207-
this.region = region;
208-
return this;
209-
}
210-
211-
public String getRegion() {
212-
return region;
206+
providerId(AwsConstants.PROVIDER_ID);
213207
}
214208

215209
@Override

registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class AuthStrippingInterceptorTest {
2020
private static final String REGISTRY_ENDPOINT = "https://123456789012.dkr.ecr.us-east-1.amazonaws.com";
2121
private static final String REGISTRY_HOST = "123456789012.dkr.ecr.us-east-1.amazonaws.com";
2222
private static final String EXAMPLE_ENDPOINT_HOST = "registry.example.com";
23+
private static final String S3_HOST = "s3.amazonaws.com";
2324
private static final String GET = "GET";
2425
private static final String BEARER_TOKEN = "Bearer token123";
2526
private static final String PATH_BLOB = "/v2/repo/blobs/sha256:abc";
@@ -78,7 +79,7 @@ void testProcess_CaseInsensitiveHostComparison() {
7879
@Test
7980
void testProcess_NoAuthHeader_NoError() {
8081
HttpRequest request = new BasicHttpRequest(GET, PATH_BLOB);
81-
HttpContext context = contextWithHost("s3.amazonaws.com");
82+
HttpContext context = contextWithHost(S3_HOST);
8283

8384
interceptor.process(request, context);
8485

registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java

Lines changed: 108 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
55
import com.salesforce.multicloudj.common.exceptions.UnAuthorizedException;
66
import com.salesforce.multicloudj.common.exceptions.UnknownException;
7+
import org.apache.http.HttpRequestInterceptor;
78
import org.junit.jupiter.api.Test;
89
import org.junit.jupiter.params.ParameterizedTest;
910
import org.junit.jupiter.params.provider.MethodSource;
@@ -19,10 +20,14 @@
1920
import java.time.Instant;
2021
import java.util.Base64;
2122
import java.util.Collections;
23+
import java.util.List;
2224
import java.util.stream.Stream;
2325

2426
import static org.junit.jupiter.api.Assertions.assertEquals;
27+
import static org.junit.jupiter.api.Assertions.assertFalse;
28+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
2529
import static org.junit.jupiter.api.Assertions.assertNotNull;
30+
import static org.junit.jupiter.api.Assertions.assertNull;
2631
import static org.junit.jupiter.api.Assertions.assertThrows;
2732
import static org.junit.jupiter.params.provider.Arguments.arguments;
2833
import static org.mockito.ArgumentMatchers.any;
@@ -54,7 +59,6 @@ private AwsRegistry createRegistryWithMockEcrClient(EcrClient mockEcrClient) {
5459
AwsRegistry.Builder builder = new AwsRegistry.Builder()
5560
.withRegion(TEST_REGION)
5661
.withRegistryEndpoint(TEST_REGISTRY_ENDPOINT);
57-
builder.providerId(PROVIDER_ID);
5862
return new AwsRegistry(builder, mockEcrClient);
5963
}
6064

@@ -81,16 +85,6 @@ private static AwsServiceException awsException(String errorCode) {
8185
.build();
8286
}
8387

84-
static Stream<org.junit.jupiter.params.provider.Arguments> exceptionMappingProvider() { // NOSONAR - needed for @MethodSource resolution
85-
return Stream.of(
86-
arguments(new SubstrateSdkException("test"), SubstrateSdkException.class),
87-
arguments(awsException(AWS_ERROR_CODE_ACCESS_DENIED), UnAuthorizedException.class),
88-
arguments(awsException(AWS_ERROR_CODE_UNAVAILABLE), UnknownException.class),
89-
arguments(new IllegalArgumentException("invalid"), InvalidArgumentException.class),
90-
arguments(new RuntimeException("unknown"), UnknownException.class)
91-
);
92-
}
93-
9488
private void withMockedRegistry(RegistryTestAction action) throws Exception {
9589
EcrClient mockEcrClient = mock(EcrClient.class);
9690
try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
@@ -99,62 +93,78 @@ private void withMockedRegistry(RegistryTestAction action) throws Exception {
9993
}
10094

10195
@Test
102-
void testBuilderAndBasicProperties() throws Exception {
96+
void testNoArgConstructor_CreatesInstanceWithDefaultBuilder() {
97+
AwsRegistry registry = new AwsRegistry();
98+
assertNotNull(registry);
99+
assertNull(registry.getOciClient());
100+
}
101+
102+
@Test
103+
void testConstructor_WithBuilder_InitialisesFields() throws Exception {
103104
withMockedRegistry(registry -> {
104-
assertNotNull(registry);
105105
assertEquals(PROVIDER_ID, registry.getProviderId());
106106
assertEquals(AUTH_USERNAME, registry.getAuthUsername());
107-
assertNotNull(registry.builder());
107+
assertNotNull(registry.getOciClient());
108108
});
109109
}
110110

111+
@Test
112+
void testBuilder_InstanceMethod_ReturnsNewBuilder() throws Exception {
113+
withMockedRegistry(registry -> assertNotNull(registry.builder()));
114+
}
115+
116+
@Test
117+
void testBuilder_WithRegion_GetRegion_RoundTrip() {
118+
AwsRegistry.Builder builder = new AwsRegistry.Builder().withRegion(TEST_REGION);
119+
assertEquals(TEST_REGION, builder.getRegion());
120+
}
121+
111122
@ParameterizedTest
112123
@NullAndEmptySource
113124
@ValueSource(strings = {" "})
114-
void testBuilder_InvalidRegion_ThrowsException(String region) {
125+
void testBuilder_MissingEndpoint_ThrowsInvalidArgumentException(String endpoint) {
115126
assertThrows(InvalidArgumentException.class, () ->
116-
new AwsRegistry.Builder()
117-
.withRegion(region)
118-
.withRegistryEndpoint(TEST_REGISTRY_ENDPOINT)
119-
.build()
120-
);
127+
new AwsRegistry.Builder()
128+
.withRegion(TEST_REGION)
129+
.withRegistryEndpoint(endpoint)
130+
.build());
121131
}
122132

123-
@Test
124-
void testGetAuthToken_EmptyAuthorizationData_ThrowsUnknownException() throws Exception {
125-
EcrClient mockEcrClient = mock(EcrClient.class);
126-
GetAuthorizationTokenResponse response = GetAuthorizationTokenResponse.builder()
127-
.authorizationData(Collections.emptyList())
128-
.build();
129-
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))).thenReturn(response);
133+
@ParameterizedTest
134+
@NullAndEmptySource
135+
@ValueSource(strings = {" "})
136+
void testBuilder_MissingRegion_ThrowsInvalidArgumentException(String region) {
137+
assertThrows(InvalidArgumentException.class, () ->
138+
new AwsRegistry.Builder()
139+
.withRegion(region)
140+
.withRegistryEndpoint(TEST_REGISTRY_ENDPOINT)
141+
.build());
142+
}
130143

131-
try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
132-
UnknownException exception = assertThrows(UnknownException.class, registry::getAuthToken);
133-
assertEquals(ERR_EMPTY_AUTH_DATA, exception.getMessage());
134-
}
144+
@Test
145+
void testGetInterceptors_ReturnsAuthStrippingInterceptor() throws Exception {
146+
withMockedRegistry(registry -> {
147+
List<HttpRequestInterceptor> interceptors = registry.getInterceptors();
148+
assertFalse(interceptors.isEmpty());
149+
assertInstanceOf(AuthStrippingInterceptor.class, interceptors.get(0));
150+
});
135151
}
136152

137153
@Test
138-
void testGetAuthToken_InvalidTokenFormat_ThrowsUnknownException() throws Exception {
139-
EcrClient mockEcrClient = mock(EcrClient.class);
140-
AuthorizationData authData = AuthorizationData.builder()
141-
.authorizationToken(Base64.getEncoder().encodeToString("invalidtoken".getBytes()))
142-
.expiresAt(Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))
143-
.build();
144-
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
145-
.thenReturn(tokenResponse(authData));
154+
void testGetOciClient_ReturnsNonNull_WhenEndpointProvided() throws Exception {
155+
withMockedRegistry(registry -> assertNotNull(registry.getOciClient()));
156+
}
146157

147-
try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
148-
UnknownException exception = assertThrows(UnknownException.class, registry::getAuthToken);
149-
assertEquals(ERR_INVALID_TOKEN_FORMAT, exception.getMessage());
150-
}
158+
@Test
159+
void testGetOciClient_ReturnsNull_WhenNoEndpoint() {
160+
AwsRegistry registry = new AwsRegistry();
161+
assertNull(registry.getOciClient());
151162
}
152163

153164
@Test
154165
void testGetAuthToken_TokenCachedWithinHalfwayWindow_NoRefresh() throws Exception {
155166
String expectedToken = "cached-token";
156167
EcrClient mockEcrClient = mock(EcrClient.class);
157-
// Token expires 12 hours from now — halfway point is 6 hours from now, so no refresh expected
158168
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
159169
.thenReturn(tokenResponse(authDataWithExpiry(expectedToken, Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))));
160170

@@ -171,7 +181,6 @@ void testGetAuthToken_TokenPastHalfwayPoint_Refreshes() throws Exception {
171181
String refreshedToken = "refreshed-token";
172182
EcrClient mockEcrClient = mock(EcrClient.class);
173183
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
174-
// Token already past halfway point (expired 1 second ago)
175184
.thenReturn(tokenResponse(authDataWithExpiry(firstToken, Instant.now().minusSeconds(1))))
176185
.thenReturn(tokenResponse(authDataWithExpiry(refreshedToken, Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))));
177186

@@ -182,12 +191,39 @@ void testGetAuthToken_TokenPastHalfwayPoint_Refreshes() throws Exception {
182191
}
183192
}
184193

194+
@Test
195+
void testGetAuthToken_EmptyAuthorizationData_ThrowsUnknownException() throws Exception {
196+
EcrClient mockEcrClient = mock(EcrClient.class);
197+
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
198+
.thenReturn(GetAuthorizationTokenResponse.builder().authorizationData(Collections.emptyList()).build());
199+
200+
try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
201+
UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken);
202+
assertEquals(ERR_EMPTY_AUTH_DATA, ex.getMessage());
203+
}
204+
}
205+
206+
@Test
207+
void testGetAuthToken_InvalidTokenFormat_ThrowsUnknownException() throws Exception {
208+
EcrClient mockEcrClient = mock(EcrClient.class);
209+
AuthorizationData authData = AuthorizationData.builder()
210+
.authorizationToken(Base64.getEncoder().encodeToString("invalidtoken".getBytes()))
211+
.expiresAt(Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))
212+
.build();
213+
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
214+
.thenReturn(tokenResponse(authData));
215+
216+
try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
217+
UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken);
218+
assertEquals(ERR_INVALID_TOKEN_FORMAT, ex.getMessage());
219+
}
220+
}
221+
185222
@Test
186223
void testGetAuthToken_RefreshFails_FallsBackToCachedToken() throws Exception {
187224
String cachedToken = "still-valid-token";
188225
EcrClient mockEcrClient = mock(EcrClient.class);
189226
when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
190-
// First call primes cache with past-halfway token, second call simulates transient failure
191227
.thenReturn(tokenResponse(authDataWithExpiry(cachedToken, Instant.now().minusSeconds(1))))
192228
.thenThrow(awsException(AWS_ERROR_CODE_UNAVAILABLE));
193229

@@ -204,11 +240,35 @@ void testGetAuthToken_RefreshFails_NoCachedToken_ThrowsUnknownException() throws
204240
.thenThrow(awsException(AWS_ERROR_CODE_UNAVAILABLE));
205241

206242
try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
207-
UnknownException exception = assertThrows(UnknownException.class, registry::getAuthToken);
208-
assertEquals(ERR_FAILED_AUTH_TOKEN, exception.getMessage());
243+
UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken);
244+
assertEquals(ERR_FAILED_AUTH_TOKEN, ex.getMessage());
209245
}
210246
}
211247

248+
@Test
249+
void testClose_WithOciAndEcrClient_ClosesAll() throws Exception {
250+
EcrClient mockEcrClient = mock(EcrClient.class);
251+
AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient);
252+
registry.close();
253+
verify(mockEcrClient).close();
254+
}
255+
256+
@Test
257+
void testClose_WithNullEcrClient_NoError() throws Exception {
258+
AwsRegistry registry = new AwsRegistry();
259+
registry.close(); // should not throw
260+
}
261+
262+
static Stream<org.junit.jupiter.params.provider.Arguments> exceptionMappingProvider() { // NOSONAR - needed for @MethodSource resolution
263+
return Stream.of(
264+
arguments(new SubstrateSdkException("test"), SubstrateSdkException.class),
265+
arguments(awsException(AWS_ERROR_CODE_ACCESS_DENIED), UnAuthorizedException.class),
266+
arguments(awsException(AWS_ERROR_CODE_UNAVAILABLE), UnknownException.class),
267+
arguments(new IllegalArgumentException("invalid"), InvalidArgumentException.class),
268+
arguments(new RuntimeException("unknown"), UnknownException.class)
269+
);
270+
}
271+
212272
@ParameterizedTest
213273
@MethodSource("exceptionMappingProvider")
214274
void testGetException(Throwable input, Class<? extends SubstrateSdkException> expected) throws Exception {

registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ public ContainerRegistryClientBuilder withRegistryEndpoint(String registryEndpoi
7070
return this;
7171
}
7272

73+
/** Sets the region. Required for AWS; ignored by GCP. */
74+
public ContainerRegistryClientBuilder withRegion(String region) {
75+
this.registryBuilder.withRegion(region);
76+
return this;
77+
}
78+
7379
/** Sets a proxy endpoint override for HTTP requests. */
7480
public ContainerRegistryClientBuilder withProxyEndpoint(java.net.URI proxyEndpoint) {
7581
this.registryBuilder.withProxyEndpoint(proxyEndpoint);

0 commit comments

Comments
 (0)