diff --git a/registry/registry-aws/pom.xml b/registry/registry-aws/pom.xml index ca6d3544c..51efb2970 100644 --- a/registry/registry-aws/pom.xml +++ b/registry/registry-aws/pom.xml @@ -90,6 +90,12 @@ 5.12.1 test + + org.junit.jupiter + junit-jupiter-params + 5.12.1 + test + org.mockito mockito-junit-jupiter diff --git a/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptor.java b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptor.java new file mode 100644 index 000000000..0b1bb2845 --- /dev/null +++ b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptor.java @@ -0,0 +1,44 @@ +package com.salesforce.multicloudj.registry.aws; + +import java.net.URI; +import org.apache.http.HttpHeaders; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.HttpRequestInterceptor; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.protocol.HttpContext; + +/** + * HTTP request interceptor that strips Authorization headers when the request target is not the + * registry host. + * + *

This is necessary because AWS ECR redirects blob downloads to S3 pre-signed URLs, which + * already contain authentication in query parameters. Sending an Authorization header to S3 causes + * a 400 error ("Only one auth mechanism allowed"). + */ +public class AuthStrippingInterceptor implements HttpRequestInterceptor { + private final String registryHost; + + /** + * @param registryEndpoint the registry base URL + */ + public AuthStrippingInterceptor(String registryEndpoint) { + this.registryHost = extractHost(registryEndpoint); + } + + @Override + public void process(HttpRequest request, HttpContext context) { + HttpHost targetHost = (HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST); + if (targetHost == null) { + return; + } + if (!registryHost.equalsIgnoreCase(targetHost.getHostName())) { + request.removeHeaders(HttpHeaders.AUTHORIZATION); + } + } + + /** Extracts the hostname from a URL, stripping scheme, port, and path. */ + private static String extractHost(String url) { + return URI.create(url).getHost(); + } +} diff --git a/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java new file mode 100644 index 000000000..bb2ae5237 --- /dev/null +++ b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java @@ -0,0 +1,223 @@ +package com.salesforce.multicloudj.registry.aws; + +import com.google.auto.service.AutoService; +import com.salesforce.multicloudj.common.aws.AwsConstants; +import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping; +import com.salesforce.multicloudj.common.aws.CredentialsProvider; +import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException; +import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException; +import com.salesforce.multicloudj.common.exceptions.UnknownException; +import com.salesforce.multicloudj.registry.driver.AbstractRegistry; +import com.salesforce.multicloudj.registry.driver.OciRegistryClient; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.List; +import org.apache.commons.lang3.StringUtils; +import org.apache.http.HttpRequestInterceptor; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.ecr.EcrClient; +import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest; +import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse; + +/** + * AWS Elastic Container Registry (ECR) implementation. + * + *

Authentication uses ECR's GetAuthorizationToken API which returns a Base64-encoded {@code + * AWS:} pair. The token is valid for 12 hours and is cached in memory. It is proactively + * refreshed at the halfway point of its validity window (6 hours). If a refresh fails and a cached + * tokenis available, the cached token is reused as a fallback rather than failing the request. + */ +@AutoService(AbstractRegistry.class) +public class AwsRegistry extends AbstractRegistry { + + private static final String AWS_AUTH_USERNAME = "AWS"; + + /** Lock for thread-safe lazy initialization of ECR client. */ + private final Object ecrClientLock = new Object(); + + /** Lock for thread-safe token refresh. */ + private final Object tokenLock = new Object(); + + private final OciRegistryClient ociClient; + + /** Lazily initialized ECR client with double-checked locking. */ + private volatile EcrClient ecrClient; + + private volatile String cachedAuthToken; + private volatile long tokenRequestedAt; + private volatile long tokenExpirationTime; + + public AwsRegistry() { + this(new Builder()); + } + + public AwsRegistry(Builder builder) { + this(builder, null); + } + + /** + * Creates AwsRegistry with specified EcrClient. + * + * @param builder the builder with configuration + * @param ecrClient the ECR client to use (null to create default) + */ + public AwsRegistry(Builder builder, EcrClient ecrClient) { + super(builder); + this.ecrClient = ecrClient; + this.ociClient = + registryEndpoint != null ? new OciRegistryClient(registryEndpoint, this) : null; + } + + @Override + public Builder builder() { + return new Builder(); + } + + @Override + protected OciRegistryClient getOciClient() { + return ociClient; + } + + @Override + public String getAuthUsername() { + return AWS_AUTH_USERNAME; + } + + @Override + public String getAuthToken() { + if (cachedAuthToken == null || isPastRefreshPoint()) { + synchronized (tokenLock) { + if (cachedAuthToken == null || isPastRefreshPoint()) { + refreshAuthToken(); + } + } + } + return cachedAuthToken; + } + + @Override + protected List getInterceptors() { + return List.of(new AuthStrippingInterceptor(registryEndpoint)); + } + + /** + * Returns true if the current time is past the halfway point of the token's validity window, at + * which point a proactive refresh is triggered i.e. treats tokens as invalid after 50% of their + * lifetime. + */ + private boolean isPastRefreshPoint() { + long halfwayPoint = tokenRequestedAt + (tokenExpirationTime - tokenRequestedAt) / 2; + return System.currentTimeMillis() >= halfwayPoint; + } + + /** Returns the ECR client, initializing lazily with double-checked locking. */ + private EcrClient getOrCreateEcrClient() { + if (ecrClient == null) { + synchronized (ecrClientLock) { + if (ecrClient == null) { + ecrClient = createEcrClient(); + } + } + } + return ecrClient; + } + + private EcrClient createEcrClient() { + Region awsRegion = Region.of(region); + AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create(); + if (credentialsOverrider != null) { + AwsCredentialsProvider overrideProvider = + CredentialsProvider.getCredentialsProvider(credentialsOverrider, awsRegion); + if (overrideProvider != null) { + credentialsProvider = overrideProvider; + } + } + return EcrClient.builder().region(awsRegion).credentialsProvider(credentialsProvider).build(); + } + + /** + * Fetches a fresh ECR authorization token and updates the cache. On {@link AwsServiceException}, + * falls back to the existing cached token if one is available. + */ + private void refreshAuthToken() { + try { + GetAuthorizationTokenResponse response = + getOrCreateEcrClient() + .getAuthorizationToken(GetAuthorizationTokenRequest.builder().build()); + + if (response.authorizationData().isEmpty()) { + throw new UnknownException("ECR returned empty authorization data"); + } + + // ECR token is Base64-encoded "AWS:"; extract the password portion + String encodedToken = response.authorizationData().get(0).authorizationToken(); + String decodedToken = + new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8); + String[] parts = decodedToken.split(":", 2); + if (parts.length != 2) { + throw new UnknownException("Invalid ECR authorization token format"); + } + + cachedAuthToken = parts[1]; + tokenRequestedAt = System.currentTimeMillis(); + tokenExpirationTime = response.authorizationData().get(0).expiresAt().toEpochMilli(); + } catch (AwsServiceException e) { + if (cachedAuthToken != null) { + return; + } + throw new UnknownException("Failed to get ECR authorization token", e); + } + } + + @Override + public Class getException(Throwable t) { + if (t instanceof SubstrateSdkException) { + return (Class) t.getClass(); + } else if (t instanceof AwsServiceException) { + AwsServiceException awsException = (AwsServiceException) t; + String errorCode = awsException.awsErrorDetails().errorCode(); + Class mappedException = + CommonErrorCodeMapping.get().get(errorCode); + return mappedException != null ? mappedException : UnknownException.class; + } else if (t instanceof IllegalArgumentException) { + return InvalidArgumentException.class; + } + return UnknownException.class; + } + + @Override + public void close() throws Exception { + if (ociClient != null) { + ociClient.close(); + } + if (ecrClient != null) { + ecrClient.close(); + } + } + + public static final class Builder extends AbstractRegistry.Builder { + + public Builder() { + providerId(AwsConstants.PROVIDER_ID); + } + + @Override + public Builder self() { + return this; + } + + @Override + public AwsRegistry build() { + if (StringUtils.isBlank(registryEndpoint)) { + throw new InvalidArgumentException("Registry endpoint is required for AWS ECR"); + } + if (StringUtils.isBlank(region)) { + throw new InvalidArgumentException("AWS region is required"); + } + return new AwsRegistry(this); + } + } +} diff --git a/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java new file mode 100644 index 000000000..64c32549c --- /dev/null +++ b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java @@ -0,0 +1,122 @@ +package com.salesforce.multicloudj.registry.aws; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.http.HttpHeaders; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.http.message.BasicHttpRequest; +import org.apache.http.protocol.BasicHttpContext; +import org.apache.http.protocol.HttpContext; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class AuthStrippingInterceptorTest { + + private static final String REGISTRY_ENDPOINT = + "https://123456789012.dkr.ecr.us-east-1.amazonaws.com"; + private static final String REGISTRY_HOST = "123456789012.dkr.ecr.us-east-1.amazonaws.com"; + private static final String EXAMPLE_ENDPOINT_HOST = "registry.example.com"; + private static final String S3_HOST = "s3.amazonaws.com"; + private static final String GET = "GET"; + private static final String BEARER_TOKEN = "Bearer token123"; + private static final String PATH_BLOB = "/v2/repo/blobs/sha256:abc"; + + private AuthStrippingInterceptor interceptor; + + @BeforeEach + void setUp() { + interceptor = new AuthStrippingInterceptor(REGISTRY_ENDPOINT); + } + + @Test + void testProcess_KeepsAuthHeader_WhenTargetIsRegistryHost() { + HttpRequest request = requestWithAuth(PATH_BLOB); + HttpContext context = contextWithHost(REGISTRY_HOST); + + interceptor.process(request, context); + + assertTrue(request.containsHeader(HttpHeaders.AUTHORIZATION)); + } + + @ParameterizedTest + @ValueSource( + strings = { + "s3.amazonaws.com", + "prod-us-east-1-starport-layer-bucket.s3.us-east-1.amazonaws.com" + }) + void testProcess_StripsAuthHeader_WhenTargetIsExternalHost(String externalHost) { + HttpRequest request = requestWithAuth(PATH_BLOB); + HttpContext context = contextWithHost(externalHost); + + interceptor.process(request, context); + + assertFalse(request.containsHeader(HttpHeaders.AUTHORIZATION)); + } + + @Test + void testProcess_NoOp_WhenTargetHostIsNull() { + HttpRequest request = requestWithAuth(PATH_BLOB); + HttpContext context = new BasicHttpContext(); + + interceptor.process(request, context); + + assertTrue(request.containsHeader(HttpHeaders.AUTHORIZATION)); + } + + @Test + void testProcess_CaseInsensitiveHostComparison() { + HttpRequest request = requestWithAuth(PATH_BLOB); + HttpContext context = contextWithHost(REGISTRY_HOST.toUpperCase()); + + interceptor.process(request, context); + + assertTrue(request.containsHeader(HttpHeaders.AUTHORIZATION)); + } + + @Test + void testProcess_NoAuthHeader_NoError() { + HttpRequest request = new BasicHttpRequest(GET, PATH_BLOB); + HttpContext context = contextWithHost(S3_HOST); + + interceptor.process(request, context); + + assertFalse(request.containsHeader(HttpHeaders.AUTHORIZATION)); + } + + @ParameterizedTest + @ValueSource( + strings = { + "https://registry.example.com/", + "https://registry.example.com:443", + "https://registry.example.com:443/v2/", + "http://registry.example.com" + }) + void testExtractHost_CorrectlyParsesVariousUrlFormats(String endpoint) { + AuthStrippingInterceptor localInterceptor = new AuthStrippingInterceptor(endpoint); + HttpRequest request = requestWithAuth("/v2/"); + HttpContext context = contextWithHost(EXAMPLE_ENDPOINT_HOST); + + localInterceptor.process(request, context); + + assertTrue( + request.containsHeader(HttpHeaders.AUTHORIZATION), + "Should correctly extract host from: " + endpoint); + } + + private HttpRequest requestWithAuth(String path) { + HttpRequest request = new BasicHttpRequest(GET, path); + request.setHeader(HttpHeaders.AUTHORIZATION, BEARER_TOKEN); + return request; + } + + private HttpContext contextWithHost(String host) { + HttpContext context = new BasicHttpContext(); + context.setAttribute(HttpClientContext.HTTP_TARGET_HOST, new HttpHost(host, 443, "https")); + return context; + } +} diff --git a/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java new file mode 100644 index 000000000..5aa3490d5 --- /dev/null +++ b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java @@ -0,0 +1,298 @@ +package com.salesforce.multicloudj.registry.aws; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.params.provider.Arguments.arguments; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException; +import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException; +import com.salesforce.multicloudj.common.exceptions.UnAuthorizedException; +import com.salesforce.multicloudj.common.exceptions.UnknownException; +import java.time.Instant; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.stream.Stream; +import org.apache.http.HttpRequestInterceptor; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.NullAndEmptySource; +import org.junit.jupiter.params.provider.ValueSource; +import software.amazon.awssdk.awscore.exception.AwsErrorDetails; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.services.ecr.EcrClient; +import software.amazon.awssdk.services.ecr.model.AuthorizationData; +import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest; +import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse; + +class AwsRegistryTest { + + private static final String TEST_REGION = "us-east-1"; + private static final String TEST_REGISTRY_ENDPOINT = + "https://123456789012.dkr.ecr.us-east-1.amazonaws.com"; + private static final String PROVIDER_ID = "aws"; + private static final String AUTH_USERNAME = "AWS"; + private static final String TOKEN_PREFIX = "AWS:"; + private static final String ERR_EMPTY_AUTH_DATA = "ECR returned empty authorization data"; + private static final String ERR_INVALID_TOKEN_FORMAT = "Invalid ECR authorization token format"; + private static final String ERR_FAILED_AUTH_TOKEN = "Failed to get ECR authorization token"; + private static final String AWS_ERROR_CODE_ACCESS_DENIED = "AccessDenied"; + private static final String AWS_ERROR_CODE_UNAVAILABLE = "ServiceUnavailableException"; + private static final long TOKEN_VALIDITY_SECONDS = 43200; // 12 hours + + @FunctionalInterface + interface RegistryTestAction { + void execute(AwsRegistry registry) throws Exception; + } + + private AwsRegistry createRegistryWithMockEcrClient(EcrClient mockEcrClient) { + AwsRegistry.Builder builder = + new AwsRegistry.Builder() + .withRegion(TEST_REGION) + .withRegistryEndpoint(TEST_REGISTRY_ENDPOINT); + return new AwsRegistry(builder, mockEcrClient); + } + + private String encodeToken(String token) { + return Base64.getEncoder().encodeToString((TOKEN_PREFIX + token).getBytes()); + } + + private AuthorizationData authDataWithExpiry(String token, Instant expiresAt) { + return AuthorizationData.builder() + .authorizationToken(encodeToken(token)) + .expiresAt(expiresAt) + .build(); + } + + private GetAuthorizationTokenResponse tokenResponse(AuthorizationData authData) { + return GetAuthorizationTokenResponse.builder() + .authorizationData(Collections.singletonList(authData)) + .build(); + } + + private static AwsServiceException awsException(String errorCode) { + return AwsServiceException.builder() + .awsErrorDetails(AwsErrorDetails.builder().errorCode(errorCode).build()) + .build(); + } + + private void withMockedRegistry(RegistryTestAction action) throws Exception { + EcrClient mockEcrClient = mock(EcrClient.class); + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + action.execute(registry); + } + } + + @Test + void testNoArgConstructor_CreatesInstanceWithDefaultBuilder() { + AwsRegistry registry = new AwsRegistry(); + assertNotNull(registry); + assertNull(registry.getOciClient()); + } + + @Test + void testConstructor_WithBuilder_InitialisesFields() throws Exception { + withMockedRegistry( + registry -> { + assertEquals(PROVIDER_ID, registry.getProviderId()); + assertEquals(AUTH_USERNAME, registry.getAuthUsername()); + assertNotNull(registry.getOciClient()); + }); + } + + @Test + void testBuilder_InstanceMethod_ReturnsNewBuilder() throws Exception { + withMockedRegistry(registry -> assertNotNull(registry.builder())); + } + + @Test + void testBuilder_WithRegion_GetRegion_RoundTrip() { + AwsRegistry.Builder builder = new AwsRegistry.Builder().withRegion(TEST_REGION); + assertEquals(TEST_REGION, builder.getRegion()); + } + + @ParameterizedTest + @NullAndEmptySource + @ValueSource(strings = {" "}) + void testBuilder_MissingEndpoint_ThrowsInvalidArgumentException(String endpoint) { + assertThrows( + InvalidArgumentException.class, + () -> + new AwsRegistry.Builder() + .withRegion(TEST_REGION) + .withRegistryEndpoint(endpoint) + .build()); + } + + @ParameterizedTest + @NullAndEmptySource + @ValueSource(strings = {" "}) + void testBuilder_MissingRegion_ThrowsInvalidArgumentException(String region) { + assertThrows( + InvalidArgumentException.class, + () -> + new AwsRegistry.Builder() + .withRegion(region) + .withRegistryEndpoint(TEST_REGISTRY_ENDPOINT) + .build()); + } + + @Test + void testGetInterceptors_ReturnsAuthStrippingInterceptor() throws Exception { + withMockedRegistry( + registry -> { + List interceptors = registry.getInterceptors(); + assertFalse(interceptors.isEmpty()); + assertInstanceOf(AuthStrippingInterceptor.class, interceptors.get(0)); + }); + } + + @Test + void testGetOciClient_ReturnsNonNull_WhenEndpointProvided() throws Exception { + withMockedRegistry(registry -> assertNotNull(registry.getOciClient())); + } + + @Test + void testGetOciClient_ReturnsNull_WhenNoEndpoint() { + AwsRegistry registry = new AwsRegistry(); + assertNull(registry.getOciClient()); + } + + @Test + void testGetAuthToken_TokenCachedWithinHalfwayWindow_NoRefresh() throws Exception { + String expectedToken = "cached-token"; + EcrClient mockEcrClient = mock(EcrClient.class); + when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))) + .thenReturn( + tokenResponse( + authDataWithExpiry( + expectedToken, Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS)))); + + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + assertEquals(expectedToken, registry.getAuthToken()); + assertEquals(expectedToken, registry.getAuthToken()); // second call should use cache + verify(mockEcrClient, times(1)) + .getAuthorizationToken(any(GetAuthorizationTokenRequest.class)); + } + } + + @Test + void testGetAuthToken_TokenPastHalfwayPoint_Refreshes() throws Exception { + String firstToken = "first-token"; + String refreshedToken = "refreshed-token"; + EcrClient mockEcrClient = mock(EcrClient.class); + when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))) + .thenReturn(tokenResponse(authDataWithExpiry(firstToken, Instant.now().minusSeconds(1)))) + .thenReturn( + tokenResponse( + authDataWithExpiry( + refreshedToken, Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS)))); + + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + registry.getAuthToken(); // primes cache with already-past-halfway token + assertEquals(refreshedToken, registry.getAuthToken()); // triggers refresh + verify(mockEcrClient, times(2)) + .getAuthorizationToken(any(GetAuthorizationTokenRequest.class)); + } + } + + @Test + void testGetAuthToken_EmptyAuthorizationData_ThrowsUnknownException() throws Exception { + EcrClient mockEcrClient = mock(EcrClient.class); + when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))) + .thenReturn( + GetAuthorizationTokenResponse.builder() + .authorizationData(Collections.emptyList()) + .build()); + + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken); + assertEquals(ERR_EMPTY_AUTH_DATA, ex.getMessage()); + } + } + + @Test + void testGetAuthToken_InvalidTokenFormat_ThrowsUnknownException() throws Exception { + EcrClient mockEcrClient = mock(EcrClient.class); + AuthorizationData authData = + AuthorizationData.builder() + .authorizationToken(Base64.getEncoder().encodeToString("invalidtoken".getBytes())) + .expiresAt(Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS)) + .build(); + when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))) + .thenReturn(tokenResponse(authData)); + + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken); + assertEquals(ERR_INVALID_TOKEN_FORMAT, ex.getMessage()); + } + } + + @Test + void testGetAuthToken_RefreshFails_FallsBackToCachedToken() throws Exception { + String cachedToken = "still-valid-token"; + EcrClient mockEcrClient = mock(EcrClient.class); + when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))) + .thenReturn(tokenResponse(authDataWithExpiry(cachedToken, Instant.now().minusSeconds(1)))) + .thenThrow(awsException(AWS_ERROR_CODE_UNAVAILABLE)); + + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + registry.getAuthToken(); // primes cache + assertEquals(cachedToken, registry.getAuthToken()); // falls back to cached token + } + } + + @Test + void testGetAuthToken_RefreshFails_NoCachedToken_ThrowsUnknownException() throws Exception { + EcrClient mockEcrClient = mock(EcrClient.class); + when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class))) + .thenThrow(awsException(AWS_ERROR_CODE_UNAVAILABLE)); + + try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) { + UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken); + assertEquals(ERR_FAILED_AUTH_TOKEN, ex.getMessage()); + } + } + + @Test + void testClose_WithOciAndEcrClient_ClosesAll() throws Exception { + EcrClient mockEcrClient = mock(EcrClient.class); + AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient); + registry.close(); + verify(mockEcrClient).close(); + } + + @Test + void testClose_WithNullEcrClient_NoError() throws Exception { + AwsRegistry registry = new AwsRegistry(); + registry.close(); // should not throw + } + + static Stream + exceptionMappingProvider() { // NOSONAR - needed for @MethodSource resolution + return Stream.of( + arguments(new SubstrateSdkException("test"), SubstrateSdkException.class), + arguments(awsException(AWS_ERROR_CODE_ACCESS_DENIED), UnAuthorizedException.class), + arguments(awsException(AWS_ERROR_CODE_UNAVAILABLE), UnknownException.class), + arguments(new IllegalArgumentException("invalid"), InvalidArgumentException.class), + arguments(new RuntimeException("unknown"), UnknownException.class)); + } + + @ParameterizedTest + @MethodSource("exceptionMappingProvider") + void testGetException(Throwable input, Class expected) + throws Exception { + withMockedRegistry(registry -> assertEquals(expected, registry.getException(input))); + } +} diff --git a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java index b52d75236..94f8cbd88 100644 --- a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java +++ b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java @@ -5,6 +5,7 @@ import com.salesforce.multicloudj.registry.driver.AbstractRegistry; import com.salesforce.multicloudj.registry.model.Image; import com.salesforce.multicloudj.registry.model.Platform; +import com.salesforce.multicloudj.sts.model.CredentialsOverrider; import java.io.InputStream; /** @@ -73,6 +74,12 @@ public ContainerRegistryClientBuilder withRegistryEndpoint(String registryEndpoi return this; } + /** Sets the region. Required for AWS; ignored by GCP. */ + public ContainerRegistryClientBuilder withRegion(String region) { + this.registryBuilder.withRegion(region); + return this; + } + /** Sets a proxy endpoint override for HTTP requests. */ public ContainerRegistryClientBuilder withProxyEndpoint(java.net.URI proxyEndpoint) { this.registryBuilder.withProxyEndpoint(proxyEndpoint); @@ -87,7 +94,7 @@ public ContainerRegistryClientBuilder withPlatform(Platform platform) { /** Sets credentials overrider for authentication. */ public ContainerRegistryClientBuilder withCredentialsOverrider( - com.salesforce.multicloudj.sts.model.CredentialsOverrider credentialsOverrider) { + CredentialsOverrider credentialsOverrider) { this.registryBuilder.withCredentialsOverrider(credentialsOverrider); return this; } diff --git a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java index 487f61735..33c1b80a1 100644 --- a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java +++ b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java @@ -11,10 +11,12 @@ import com.salesforce.multicloudj.sts.model.CredentialsOverrider; import java.io.InputStream; import java.net.URI; +import java.util.Collections; import java.util.List; import java.util.stream.Collectors; import lombok.Getter; import org.apache.commons.lang3.StringUtils; +import org.apache.http.HttpRequestInterceptor; /** Abstract registry driver. Each cloud implements authentication and OCI client. */ public abstract class AbstractRegistry implements Provider, AutoCloseable, AuthProvider { @@ -23,6 +25,7 @@ public abstract class AbstractRegistry implements Provider, AutoCloseable, AuthP protected final String providerId; protected final String registryEndpoint; + protected final String region; protected final URI proxyEndpoint; protected final CredentialsOverrider credentialsOverrider; @Getter protected final Platform targetPlatform; @@ -30,6 +33,7 @@ public abstract class AbstractRegistry implements Provider, AutoCloseable, AuthP protected AbstractRegistry(Builder builder) { this.providerId = builder.getProviderId(); this.registryEndpoint = builder.getRegistryEndpoint(); + this.region = builder.getRegion(); this.proxyEndpoint = builder.getProxyEndpoint(); this.credentialsOverrider = builder.getCredentialsOverrider(); this.targetPlatform = builder.getPlatform() != null ? builder.getPlatform() : Platform.DEFAULT; @@ -52,6 +56,15 @@ public String getProviderId() { /** Returns the OCI client for this registry. */ protected abstract OciRegistryClient getOciClient(); + /** + * Returns the list of HTTP request interceptors to be registered with the HTTP client. Override + * this method in provider-specific subclasses to add custom interceptors. By default, returns an + * empty list (no interceptors). + */ + protected List getInterceptors() { + return Collections.emptyList(); + } + /** * Pulls an image from the registry (unified OCI flow). * @@ -110,7 +123,7 @@ public Image pull(String imageRef) { * *

* @@ -205,10 +218,16 @@ public abstract static class BuilderThis is necessary because registries like AWS ECR redirect blob downloads to S3 pre-signed - * URLs, which already contain authentication in query parameters. Sending an Authorization header - * to S3 causes a 400 error ("Only one auth mechanism allowed"). - */ - private static class AuthStrippingInterceptor implements HttpRequestInterceptor { - private final String registryHost; - - AuthStrippingInterceptor(String registryEndpoint) { - this.registryHost = extractHost(registryEndpoint); - } - - @Override - public void process(HttpRequest request, HttpContext context) { - // Get the target host from the context (works for both initial requests and redirects) - var targetHost = - (org.apache.http.HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST); - if (targetHost != null) { - String requestHost = targetHost.getHostName(); - // Strip Authorization header if request is not to the registry host - if (!registryHost.equalsIgnoreCase(requestHost)) { - request.removeHeaders(HttpHeaders.AUTHORIZATION); - } - } - } - - private static String extractHost(String url) { - String host = url.replaceFirst("^https?://", ""); - int slashIndex = host.indexOf('/'); - if (slashIndex > 0) { - host = host.substring(0, slashIndex); + HttpClientBuilder builder = HttpClients.custom(); + for (HttpRequestInterceptor interceptor : registry.getInterceptors()) { + builder.addInterceptorLast(interceptor); } - int colonIndex = host.indexOf(':'); - if (colonIndex > 0) { - host = host.substring(0, colonIndex); - } - return host; + this.httpClient = builder.build(); } + this.tokenExchange = new BearerTokenExchange(this.httpClient); } /** diff --git a/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java b/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java index f199d4baf..5a2adb60a 100644 --- a/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java +++ b/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java @@ -24,6 +24,7 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.Base64; +import java.util.function.Consumer; import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.HttpStatus; @@ -40,6 +41,7 @@ import org.mockito.MockedConstruction; import org.mockito.MockedStatic; import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; /** * Unit tests for OciRegistryClient. Tests authentication header generation for Basic, Bearer, and @@ -50,7 +52,7 @@ public class OciRegistryClientTest { private static final String REGISTRY_ENDPOINT = "https://test-registry.example.com"; private static final String REPOSITORY = "test-repo/test-image"; - @Mock private AuthProvider mockAuthProvider; + @Mock private AbstractRegistry mockAuthProvider; private AutoCloseable mocks; @@ -1010,8 +1012,7 @@ private MockedStatic mockAuthChallenge() { * @return mocked CloseableHttpClient */ private CloseableHttpClient createMockHttpClientWithExecuteAnswer( - String blobContent, - java.util.function.Consumer requestAssertion) { + String blobContent, Consumer requestAssertion) { CloseableHttpClient mockHttpClient = mock(CloseableHttpClient.class); CloseableHttpResponse mockResponse = mock(CloseableHttpResponse.class); StatusLine mockStatusLine = mock(StatusLine.class); @@ -1072,6 +1073,6 @@ private CloseableHttpClient createMockHttpClientForBlob( @FunctionalInterface interface ManifestAssertion { - void assertManifest(com.salesforce.multicloudj.registry.model.Manifest manifest); + void assertManifest(Manifest manifest); } }