org.mockito
mockito-junit-jupiter
diff --git a/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptor.java b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptor.java
new file mode 100644
index 000000000..0b1bb2845
--- /dev/null
+++ b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptor.java
@@ -0,0 +1,44 @@
+package com.salesforce.multicloudj.registry.aws;
+
+import java.net.URI;
+import org.apache.http.HttpHeaders;
+import org.apache.http.HttpHost;
+import org.apache.http.HttpRequest;
+import org.apache.http.HttpRequestInterceptor;
+import org.apache.http.client.protocol.HttpClientContext;
+import org.apache.http.protocol.HttpContext;
+
+/**
+ * HTTP request interceptor that strips Authorization headers when the request target is not the
+ * registry host.
+ *
+ * This is necessary because AWS ECR redirects blob downloads to S3 pre-signed URLs, which
+ * already contain authentication in query parameters. Sending an Authorization header to S3 causes
+ * a 400 error ("Only one auth mechanism allowed").
+ */
+public class AuthStrippingInterceptor implements HttpRequestInterceptor {
+ private final String registryHost;
+
+ /**
+ * @param registryEndpoint the registry base URL
+ */
+ public AuthStrippingInterceptor(String registryEndpoint) {
+ this.registryHost = extractHost(registryEndpoint);
+ }
+
+ @Override
+ public void process(HttpRequest request, HttpContext context) {
+ HttpHost targetHost = (HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST);
+ if (targetHost == null) {
+ return;
+ }
+ if (!registryHost.equalsIgnoreCase(targetHost.getHostName())) {
+ request.removeHeaders(HttpHeaders.AUTHORIZATION);
+ }
+ }
+
+ /** Extracts the hostname from a URL, stripping scheme, port, and path. */
+ private static String extractHost(String url) {
+ return URI.create(url).getHost();
+ }
+}
diff --git a/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java
new file mode 100644
index 000000000..bb2ae5237
--- /dev/null
+++ b/registry/registry-aws/src/main/java/com/salesforce/multicloudj/registry/aws/AwsRegistry.java
@@ -0,0 +1,223 @@
+package com.salesforce.multicloudj.registry.aws;
+
+import com.google.auto.service.AutoService;
+import com.salesforce.multicloudj.common.aws.AwsConstants;
+import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping;
+import com.salesforce.multicloudj.common.aws.CredentialsProvider;
+import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
+import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
+import com.salesforce.multicloudj.common.exceptions.UnknownException;
+import com.salesforce.multicloudj.registry.driver.AbstractRegistry;
+import com.salesforce.multicloudj.registry.driver.OciRegistryClient;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.List;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.http.HttpRequestInterceptor;
+import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
+import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
+import software.amazon.awssdk.awscore.exception.AwsServiceException;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.ecr.EcrClient;
+import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest;
+import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse;
+
+/**
+ * AWS Elastic Container Registry (ECR) implementation.
+ *
+ *
Authentication uses ECR's GetAuthorizationToken API which returns a Base64-encoded {@code
+ * AWS:} pair. The token is valid for 12 hours and is cached in memory. It is proactively
+ * refreshed at the halfway point of its validity window (6 hours). If a refresh fails and a cached
+ * tokenis available, the cached token is reused as a fallback rather than failing the request.
+ */
+@AutoService(AbstractRegistry.class)
+public class AwsRegistry extends AbstractRegistry {
+
+ private static final String AWS_AUTH_USERNAME = "AWS";
+
+ /** Lock for thread-safe lazy initialization of ECR client. */
+ private final Object ecrClientLock = new Object();
+
+ /** Lock for thread-safe token refresh. */
+ private final Object tokenLock = new Object();
+
+ private final OciRegistryClient ociClient;
+
+ /** Lazily initialized ECR client with double-checked locking. */
+ private volatile EcrClient ecrClient;
+
+ private volatile String cachedAuthToken;
+ private volatile long tokenRequestedAt;
+ private volatile long tokenExpirationTime;
+
+ public AwsRegistry() {
+ this(new Builder());
+ }
+
+ public AwsRegistry(Builder builder) {
+ this(builder, null);
+ }
+
+ /**
+ * Creates AwsRegistry with specified EcrClient.
+ *
+ * @param builder the builder with configuration
+ * @param ecrClient the ECR client to use (null to create default)
+ */
+ public AwsRegistry(Builder builder, EcrClient ecrClient) {
+ super(builder);
+ this.ecrClient = ecrClient;
+ this.ociClient =
+ registryEndpoint != null ? new OciRegistryClient(registryEndpoint, this) : null;
+ }
+
+ @Override
+ public Builder builder() {
+ return new Builder();
+ }
+
+ @Override
+ protected OciRegistryClient getOciClient() {
+ return ociClient;
+ }
+
+ @Override
+ public String getAuthUsername() {
+ return AWS_AUTH_USERNAME;
+ }
+
+ @Override
+ public String getAuthToken() {
+ if (cachedAuthToken == null || isPastRefreshPoint()) {
+ synchronized (tokenLock) {
+ if (cachedAuthToken == null || isPastRefreshPoint()) {
+ refreshAuthToken();
+ }
+ }
+ }
+ return cachedAuthToken;
+ }
+
+ @Override
+ protected List getInterceptors() {
+ return List.of(new AuthStrippingInterceptor(registryEndpoint));
+ }
+
+ /**
+ * Returns true if the current time is past the halfway point of the token's validity window, at
+ * which point a proactive refresh is triggered i.e. treats tokens as invalid after 50% of their
+ * lifetime.
+ */
+ private boolean isPastRefreshPoint() {
+ long halfwayPoint = tokenRequestedAt + (tokenExpirationTime - tokenRequestedAt) / 2;
+ return System.currentTimeMillis() >= halfwayPoint;
+ }
+
+ /** Returns the ECR client, initializing lazily with double-checked locking. */
+ private EcrClient getOrCreateEcrClient() {
+ if (ecrClient == null) {
+ synchronized (ecrClientLock) {
+ if (ecrClient == null) {
+ ecrClient = createEcrClient();
+ }
+ }
+ }
+ return ecrClient;
+ }
+
+ private EcrClient createEcrClient() {
+ Region awsRegion = Region.of(region);
+ AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
+ if (credentialsOverrider != null) {
+ AwsCredentialsProvider overrideProvider =
+ CredentialsProvider.getCredentialsProvider(credentialsOverrider, awsRegion);
+ if (overrideProvider != null) {
+ credentialsProvider = overrideProvider;
+ }
+ }
+ return EcrClient.builder().region(awsRegion).credentialsProvider(credentialsProvider).build();
+ }
+
+ /**
+ * Fetches a fresh ECR authorization token and updates the cache. On {@link AwsServiceException},
+ * falls back to the existing cached token if one is available.
+ */
+ private void refreshAuthToken() {
+ try {
+ GetAuthorizationTokenResponse response =
+ getOrCreateEcrClient()
+ .getAuthorizationToken(GetAuthorizationTokenRequest.builder().build());
+
+ if (response.authorizationData().isEmpty()) {
+ throw new UnknownException("ECR returned empty authorization data");
+ }
+
+ // ECR token is Base64-encoded "AWS:"; extract the password portion
+ String encodedToken = response.authorizationData().get(0).authorizationToken();
+ String decodedToken =
+ new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8);
+ String[] parts = decodedToken.split(":", 2);
+ if (parts.length != 2) {
+ throw new UnknownException("Invalid ECR authorization token format");
+ }
+
+ cachedAuthToken = parts[1];
+ tokenRequestedAt = System.currentTimeMillis();
+ tokenExpirationTime = response.authorizationData().get(0).expiresAt().toEpochMilli();
+ } catch (AwsServiceException e) {
+ if (cachedAuthToken != null) {
+ return;
+ }
+ throw new UnknownException("Failed to get ECR authorization token", e);
+ }
+ }
+
+ @Override
+ public Class extends SubstrateSdkException> getException(Throwable t) {
+ if (t instanceof SubstrateSdkException) {
+ return (Class extends SubstrateSdkException>) t.getClass();
+ } else if (t instanceof AwsServiceException) {
+ AwsServiceException awsException = (AwsServiceException) t;
+ String errorCode = awsException.awsErrorDetails().errorCode();
+ Class extends SubstrateSdkException> mappedException =
+ CommonErrorCodeMapping.get().get(errorCode);
+ return mappedException != null ? mappedException : UnknownException.class;
+ } else if (t instanceof IllegalArgumentException) {
+ return InvalidArgumentException.class;
+ }
+ return UnknownException.class;
+ }
+
+ @Override
+ public void close() throws Exception {
+ if (ociClient != null) {
+ ociClient.close();
+ }
+ if (ecrClient != null) {
+ ecrClient.close();
+ }
+ }
+
+ public static final class Builder extends AbstractRegistry.Builder {
+
+ public Builder() {
+ providerId(AwsConstants.PROVIDER_ID);
+ }
+
+ @Override
+ public Builder self() {
+ return this;
+ }
+
+ @Override
+ public AwsRegistry build() {
+ if (StringUtils.isBlank(registryEndpoint)) {
+ throw new InvalidArgumentException("Registry endpoint is required for AWS ECR");
+ }
+ if (StringUtils.isBlank(region)) {
+ throw new InvalidArgumentException("AWS region is required");
+ }
+ return new AwsRegistry(this);
+ }
+ }
+}
diff --git a/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java
new file mode 100644
index 000000000..64c32549c
--- /dev/null
+++ b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AuthStrippingInterceptorTest.java
@@ -0,0 +1,122 @@
+package com.salesforce.multicloudj.registry.aws;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import org.apache.http.HttpHeaders;
+import org.apache.http.HttpHost;
+import org.apache.http.HttpRequest;
+import org.apache.http.client.protocol.HttpClientContext;
+import org.apache.http.message.BasicHttpRequest;
+import org.apache.http.protocol.BasicHttpContext;
+import org.apache.http.protocol.HttpContext;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+class AuthStrippingInterceptorTest {
+
+ private static final String REGISTRY_ENDPOINT =
+ "https://123456789012.dkr.ecr.us-east-1.amazonaws.com";
+ private static final String REGISTRY_HOST = "123456789012.dkr.ecr.us-east-1.amazonaws.com";
+ private static final String EXAMPLE_ENDPOINT_HOST = "registry.example.com";
+ private static final String S3_HOST = "s3.amazonaws.com";
+ private static final String GET = "GET";
+ private static final String BEARER_TOKEN = "Bearer token123";
+ private static final String PATH_BLOB = "/v2/repo/blobs/sha256:abc";
+
+ private AuthStrippingInterceptor interceptor;
+
+ @BeforeEach
+ void setUp() {
+ interceptor = new AuthStrippingInterceptor(REGISTRY_ENDPOINT);
+ }
+
+ @Test
+ void testProcess_KeepsAuthHeader_WhenTargetIsRegistryHost() {
+ HttpRequest request = requestWithAuth(PATH_BLOB);
+ HttpContext context = contextWithHost(REGISTRY_HOST);
+
+ interceptor.process(request, context);
+
+ assertTrue(request.containsHeader(HttpHeaders.AUTHORIZATION));
+ }
+
+ @ParameterizedTest
+ @ValueSource(
+ strings = {
+ "s3.amazonaws.com",
+ "prod-us-east-1-starport-layer-bucket.s3.us-east-1.amazonaws.com"
+ })
+ void testProcess_StripsAuthHeader_WhenTargetIsExternalHost(String externalHost) {
+ HttpRequest request = requestWithAuth(PATH_BLOB);
+ HttpContext context = contextWithHost(externalHost);
+
+ interceptor.process(request, context);
+
+ assertFalse(request.containsHeader(HttpHeaders.AUTHORIZATION));
+ }
+
+ @Test
+ void testProcess_NoOp_WhenTargetHostIsNull() {
+ HttpRequest request = requestWithAuth(PATH_BLOB);
+ HttpContext context = new BasicHttpContext();
+
+ interceptor.process(request, context);
+
+ assertTrue(request.containsHeader(HttpHeaders.AUTHORIZATION));
+ }
+
+ @Test
+ void testProcess_CaseInsensitiveHostComparison() {
+ HttpRequest request = requestWithAuth(PATH_BLOB);
+ HttpContext context = contextWithHost(REGISTRY_HOST.toUpperCase());
+
+ interceptor.process(request, context);
+
+ assertTrue(request.containsHeader(HttpHeaders.AUTHORIZATION));
+ }
+
+ @Test
+ void testProcess_NoAuthHeader_NoError() {
+ HttpRequest request = new BasicHttpRequest(GET, PATH_BLOB);
+ HttpContext context = contextWithHost(S3_HOST);
+
+ interceptor.process(request, context);
+
+ assertFalse(request.containsHeader(HttpHeaders.AUTHORIZATION));
+ }
+
+ @ParameterizedTest
+ @ValueSource(
+ strings = {
+ "https://registry.example.com/",
+ "https://registry.example.com:443",
+ "https://registry.example.com:443/v2/",
+ "http://registry.example.com"
+ })
+ void testExtractHost_CorrectlyParsesVariousUrlFormats(String endpoint) {
+ AuthStrippingInterceptor localInterceptor = new AuthStrippingInterceptor(endpoint);
+ HttpRequest request = requestWithAuth("/v2/");
+ HttpContext context = contextWithHost(EXAMPLE_ENDPOINT_HOST);
+
+ localInterceptor.process(request, context);
+
+ assertTrue(
+ request.containsHeader(HttpHeaders.AUTHORIZATION),
+ "Should correctly extract host from: " + endpoint);
+ }
+
+ private HttpRequest requestWithAuth(String path) {
+ HttpRequest request = new BasicHttpRequest(GET, path);
+ request.setHeader(HttpHeaders.AUTHORIZATION, BEARER_TOKEN);
+ return request;
+ }
+
+ private HttpContext contextWithHost(String host) {
+ HttpContext context = new BasicHttpContext();
+ context.setAttribute(HttpClientContext.HTTP_TARGET_HOST, new HttpHost(host, 443, "https"));
+ return context;
+ }
+}
diff --git a/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java
new file mode 100644
index 000000000..5aa3490d5
--- /dev/null
+++ b/registry/registry-aws/src/test/java/com/salesforce/multicloudj/registry/aws/AwsRegistryTest.java
@@ -0,0 +1,298 @@
+package com.salesforce.multicloudj.registry.aws;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.params.provider.Arguments.arguments;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
+import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
+import com.salesforce.multicloudj.common.exceptions.UnAuthorizedException;
+import com.salesforce.multicloudj.common.exceptions.UnknownException;
+import java.time.Instant;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Stream;
+import org.apache.http.HttpRequestInterceptor;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.junit.jupiter.params.provider.NullAndEmptySource;
+import org.junit.jupiter.params.provider.ValueSource;
+import software.amazon.awssdk.awscore.exception.AwsErrorDetails;
+import software.amazon.awssdk.awscore.exception.AwsServiceException;
+import software.amazon.awssdk.services.ecr.EcrClient;
+import software.amazon.awssdk.services.ecr.model.AuthorizationData;
+import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest;
+import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse;
+
+class AwsRegistryTest {
+
+ private static final String TEST_REGION = "us-east-1";
+ private static final String TEST_REGISTRY_ENDPOINT =
+ "https://123456789012.dkr.ecr.us-east-1.amazonaws.com";
+ private static final String PROVIDER_ID = "aws";
+ private static final String AUTH_USERNAME = "AWS";
+ private static final String TOKEN_PREFIX = "AWS:";
+ private static final String ERR_EMPTY_AUTH_DATA = "ECR returned empty authorization data";
+ private static final String ERR_INVALID_TOKEN_FORMAT = "Invalid ECR authorization token format";
+ private static final String ERR_FAILED_AUTH_TOKEN = "Failed to get ECR authorization token";
+ private static final String AWS_ERROR_CODE_ACCESS_DENIED = "AccessDenied";
+ private static final String AWS_ERROR_CODE_UNAVAILABLE = "ServiceUnavailableException";
+ private static final long TOKEN_VALIDITY_SECONDS = 43200; // 12 hours
+
+ @FunctionalInterface
+ interface RegistryTestAction {
+ void execute(AwsRegistry registry) throws Exception;
+ }
+
+ private AwsRegistry createRegistryWithMockEcrClient(EcrClient mockEcrClient) {
+ AwsRegistry.Builder builder =
+ new AwsRegistry.Builder()
+ .withRegion(TEST_REGION)
+ .withRegistryEndpoint(TEST_REGISTRY_ENDPOINT);
+ return new AwsRegistry(builder, mockEcrClient);
+ }
+
+ private String encodeToken(String token) {
+ return Base64.getEncoder().encodeToString((TOKEN_PREFIX + token).getBytes());
+ }
+
+ private AuthorizationData authDataWithExpiry(String token, Instant expiresAt) {
+ return AuthorizationData.builder()
+ .authorizationToken(encodeToken(token))
+ .expiresAt(expiresAt)
+ .build();
+ }
+
+ private GetAuthorizationTokenResponse tokenResponse(AuthorizationData authData) {
+ return GetAuthorizationTokenResponse.builder()
+ .authorizationData(Collections.singletonList(authData))
+ .build();
+ }
+
+ private static AwsServiceException awsException(String errorCode) {
+ return AwsServiceException.builder()
+ .awsErrorDetails(AwsErrorDetails.builder().errorCode(errorCode).build())
+ .build();
+ }
+
+ private void withMockedRegistry(RegistryTestAction action) throws Exception {
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ action.execute(registry);
+ }
+ }
+
+ @Test
+ void testNoArgConstructor_CreatesInstanceWithDefaultBuilder() {
+ AwsRegistry registry = new AwsRegistry();
+ assertNotNull(registry);
+ assertNull(registry.getOciClient());
+ }
+
+ @Test
+ void testConstructor_WithBuilder_InitialisesFields() throws Exception {
+ withMockedRegistry(
+ registry -> {
+ assertEquals(PROVIDER_ID, registry.getProviderId());
+ assertEquals(AUTH_USERNAME, registry.getAuthUsername());
+ assertNotNull(registry.getOciClient());
+ });
+ }
+
+ @Test
+ void testBuilder_InstanceMethod_ReturnsNewBuilder() throws Exception {
+ withMockedRegistry(registry -> assertNotNull(registry.builder()));
+ }
+
+ @Test
+ void testBuilder_WithRegion_GetRegion_RoundTrip() {
+ AwsRegistry.Builder builder = new AwsRegistry.Builder().withRegion(TEST_REGION);
+ assertEquals(TEST_REGION, builder.getRegion());
+ }
+
+ @ParameterizedTest
+ @NullAndEmptySource
+ @ValueSource(strings = {" "})
+ void testBuilder_MissingEndpoint_ThrowsInvalidArgumentException(String endpoint) {
+ assertThrows(
+ InvalidArgumentException.class,
+ () ->
+ new AwsRegistry.Builder()
+ .withRegion(TEST_REGION)
+ .withRegistryEndpoint(endpoint)
+ .build());
+ }
+
+ @ParameterizedTest
+ @NullAndEmptySource
+ @ValueSource(strings = {" "})
+ void testBuilder_MissingRegion_ThrowsInvalidArgumentException(String region) {
+ assertThrows(
+ InvalidArgumentException.class,
+ () ->
+ new AwsRegistry.Builder()
+ .withRegion(region)
+ .withRegistryEndpoint(TEST_REGISTRY_ENDPOINT)
+ .build());
+ }
+
+ @Test
+ void testGetInterceptors_ReturnsAuthStrippingInterceptor() throws Exception {
+ withMockedRegistry(
+ registry -> {
+ List interceptors = registry.getInterceptors();
+ assertFalse(interceptors.isEmpty());
+ assertInstanceOf(AuthStrippingInterceptor.class, interceptors.get(0));
+ });
+ }
+
+ @Test
+ void testGetOciClient_ReturnsNonNull_WhenEndpointProvided() throws Exception {
+ withMockedRegistry(registry -> assertNotNull(registry.getOciClient()));
+ }
+
+ @Test
+ void testGetOciClient_ReturnsNull_WhenNoEndpoint() {
+ AwsRegistry registry = new AwsRegistry();
+ assertNull(registry.getOciClient());
+ }
+
+ @Test
+ void testGetAuthToken_TokenCachedWithinHalfwayWindow_NoRefresh() throws Exception {
+ String expectedToken = "cached-token";
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
+ .thenReturn(
+ tokenResponse(
+ authDataWithExpiry(
+ expectedToken, Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))));
+
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ assertEquals(expectedToken, registry.getAuthToken());
+ assertEquals(expectedToken, registry.getAuthToken()); // second call should use cache
+ verify(mockEcrClient, times(1))
+ .getAuthorizationToken(any(GetAuthorizationTokenRequest.class));
+ }
+ }
+
+ @Test
+ void testGetAuthToken_TokenPastHalfwayPoint_Refreshes() throws Exception {
+ String firstToken = "first-token";
+ String refreshedToken = "refreshed-token";
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
+ .thenReturn(tokenResponse(authDataWithExpiry(firstToken, Instant.now().minusSeconds(1))))
+ .thenReturn(
+ tokenResponse(
+ authDataWithExpiry(
+ refreshedToken, Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))));
+
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ registry.getAuthToken(); // primes cache with already-past-halfway token
+ assertEquals(refreshedToken, registry.getAuthToken()); // triggers refresh
+ verify(mockEcrClient, times(2))
+ .getAuthorizationToken(any(GetAuthorizationTokenRequest.class));
+ }
+ }
+
+ @Test
+ void testGetAuthToken_EmptyAuthorizationData_ThrowsUnknownException() throws Exception {
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
+ .thenReturn(
+ GetAuthorizationTokenResponse.builder()
+ .authorizationData(Collections.emptyList())
+ .build());
+
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken);
+ assertEquals(ERR_EMPTY_AUTH_DATA, ex.getMessage());
+ }
+ }
+
+ @Test
+ void testGetAuthToken_InvalidTokenFormat_ThrowsUnknownException() throws Exception {
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ AuthorizationData authData =
+ AuthorizationData.builder()
+ .authorizationToken(Base64.getEncoder().encodeToString("invalidtoken".getBytes()))
+ .expiresAt(Instant.now().plusSeconds(TOKEN_VALIDITY_SECONDS))
+ .build();
+ when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
+ .thenReturn(tokenResponse(authData));
+
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken);
+ assertEquals(ERR_INVALID_TOKEN_FORMAT, ex.getMessage());
+ }
+ }
+
+ @Test
+ void testGetAuthToken_RefreshFails_FallsBackToCachedToken() throws Exception {
+ String cachedToken = "still-valid-token";
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
+ .thenReturn(tokenResponse(authDataWithExpiry(cachedToken, Instant.now().minusSeconds(1))))
+ .thenThrow(awsException(AWS_ERROR_CODE_UNAVAILABLE));
+
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ registry.getAuthToken(); // primes cache
+ assertEquals(cachedToken, registry.getAuthToken()); // falls back to cached token
+ }
+ }
+
+ @Test
+ void testGetAuthToken_RefreshFails_NoCachedToken_ThrowsUnknownException() throws Exception {
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ when(mockEcrClient.getAuthorizationToken(any(GetAuthorizationTokenRequest.class)))
+ .thenThrow(awsException(AWS_ERROR_CODE_UNAVAILABLE));
+
+ try (AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient)) {
+ UnknownException ex = assertThrows(UnknownException.class, registry::getAuthToken);
+ assertEquals(ERR_FAILED_AUTH_TOKEN, ex.getMessage());
+ }
+ }
+
+ @Test
+ void testClose_WithOciAndEcrClient_ClosesAll() throws Exception {
+ EcrClient mockEcrClient = mock(EcrClient.class);
+ AwsRegistry registry = createRegistryWithMockEcrClient(mockEcrClient);
+ registry.close();
+ verify(mockEcrClient).close();
+ }
+
+ @Test
+ void testClose_WithNullEcrClient_NoError() throws Exception {
+ AwsRegistry registry = new AwsRegistry();
+ registry.close(); // should not throw
+ }
+
+ static Stream
+ exceptionMappingProvider() { // NOSONAR - needed for @MethodSource resolution
+ return Stream.of(
+ arguments(new SubstrateSdkException("test"), SubstrateSdkException.class),
+ arguments(awsException(AWS_ERROR_CODE_ACCESS_DENIED), UnAuthorizedException.class),
+ arguments(awsException(AWS_ERROR_CODE_UNAVAILABLE), UnknownException.class),
+ arguments(new IllegalArgumentException("invalid"), InvalidArgumentException.class),
+ arguments(new RuntimeException("unknown"), UnknownException.class));
+ }
+
+ @ParameterizedTest
+ @MethodSource("exceptionMappingProvider")
+ void testGetException(Throwable input, Class extends SubstrateSdkException> expected)
+ throws Exception {
+ withMockedRegistry(registry -> assertEquals(expected, registry.getException(input)));
+ }
+}
diff --git a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java
index b52d75236..94f8cbd88 100644
--- a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java
+++ b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/client/ContainerRegistryClient.java
@@ -5,6 +5,7 @@
import com.salesforce.multicloudj.registry.driver.AbstractRegistry;
import com.salesforce.multicloudj.registry.model.Image;
import com.salesforce.multicloudj.registry.model.Platform;
+import com.salesforce.multicloudj.sts.model.CredentialsOverrider;
import java.io.InputStream;
/**
@@ -73,6 +74,12 @@ public ContainerRegistryClientBuilder withRegistryEndpoint(String registryEndpoi
return this;
}
+ /** Sets the region. Required for AWS; ignored by GCP. */
+ public ContainerRegistryClientBuilder withRegion(String region) {
+ this.registryBuilder.withRegion(region);
+ return this;
+ }
+
/** Sets a proxy endpoint override for HTTP requests. */
public ContainerRegistryClientBuilder withProxyEndpoint(java.net.URI proxyEndpoint) {
this.registryBuilder.withProxyEndpoint(proxyEndpoint);
@@ -87,7 +94,7 @@ public ContainerRegistryClientBuilder withPlatform(Platform platform) {
/** Sets credentials overrider for authentication. */
public ContainerRegistryClientBuilder withCredentialsOverrider(
- com.salesforce.multicloudj.sts.model.CredentialsOverrider credentialsOverrider) {
+ CredentialsOverrider credentialsOverrider) {
this.registryBuilder.withCredentialsOverrider(credentialsOverrider);
return this;
}
diff --git a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java
index 487f61735..33c1b80a1 100644
--- a/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java
+++ b/registry/registry-client/src/main/java/com/salesforce/multicloudj/registry/driver/AbstractRegistry.java
@@ -11,10 +11,12 @@
import com.salesforce.multicloudj.sts.model.CredentialsOverrider;
import java.io.InputStream;
import java.net.URI;
+import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Getter;
import org.apache.commons.lang3.StringUtils;
+import org.apache.http.HttpRequestInterceptor;
/** Abstract registry driver. Each cloud implements authentication and OCI client. */
public abstract class AbstractRegistry implements Provider, AutoCloseable, AuthProvider {
@@ -23,6 +25,7 @@ public abstract class AbstractRegistry implements Provider, AutoCloseable, AuthP
protected final String providerId;
protected final String registryEndpoint;
+ protected final String region;
protected final URI proxyEndpoint;
protected final CredentialsOverrider credentialsOverrider;
@Getter protected final Platform targetPlatform;
@@ -30,6 +33,7 @@ public abstract class AbstractRegistry implements Provider, AutoCloseable, AuthP
protected AbstractRegistry(Builder, ?> builder) {
this.providerId = builder.getProviderId();
this.registryEndpoint = builder.getRegistryEndpoint();
+ this.region = builder.getRegion();
this.proxyEndpoint = builder.getProxyEndpoint();
this.credentialsOverrider = builder.getCredentialsOverrider();
this.targetPlatform = builder.getPlatform() != null ? builder.getPlatform() : Platform.DEFAULT;
@@ -52,6 +56,15 @@ public String getProviderId() {
/** Returns the OCI client for this registry. */
protected abstract OciRegistryClient getOciClient();
+ /**
+ * Returns the list of HTTP request interceptors to be registered with the HTTP client. Override
+ * this method in provider-specific subclasses to add custom interceptors. By default, returns an
+ * empty list (no interceptors).
+ */
+ protected List getInterceptors() {
+ return Collections.emptyList();
+ }
+
/**
* Pulls an image from the registry (unified OCI flow).
*
@@ -110,7 +123,7 @@ public Image pull(String imageRef) {
*
*
* - OS, Architecture, Variant, and OS version must match
- *
- OS features in the spec must be a subset of the entry's OS features
+ *
- OS features in the spec must be a subset of the entry's OS features
*
- Empty/null fields in the target platform are treated as wildcards
*
*
@@ -205,10 +218,16 @@ public abstract static class BuilderThis is necessary because registries like AWS ECR redirect blob downloads to S3 pre-signed
- * URLs, which already contain authentication in query parameters. Sending an Authorization header
- * to S3 causes a 400 error ("Only one auth mechanism allowed").
- */
- private static class AuthStrippingInterceptor implements HttpRequestInterceptor {
- private final String registryHost;
-
- AuthStrippingInterceptor(String registryEndpoint) {
- this.registryHost = extractHost(registryEndpoint);
- }
-
- @Override
- public void process(HttpRequest request, HttpContext context) {
- // Get the target host from the context (works for both initial requests and redirects)
- var targetHost =
- (org.apache.http.HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST);
- if (targetHost != null) {
- String requestHost = targetHost.getHostName();
- // Strip Authorization header if request is not to the registry host
- if (!registryHost.equalsIgnoreCase(requestHost)) {
- request.removeHeaders(HttpHeaders.AUTHORIZATION);
- }
- }
- }
-
- private static String extractHost(String url) {
- String host = url.replaceFirst("^https?://", "");
- int slashIndex = host.indexOf('/');
- if (slashIndex > 0) {
- host = host.substring(0, slashIndex);
+ HttpClientBuilder builder = HttpClients.custom();
+ for (HttpRequestInterceptor interceptor : registry.getInterceptors()) {
+ builder.addInterceptorLast(interceptor);
}
- int colonIndex = host.indexOf(':');
- if (colonIndex > 0) {
- host = host.substring(0, colonIndex);
- }
- return host;
+ this.httpClient = builder.build();
}
+ this.tokenExchange = new BearerTokenExchange(this.httpClient);
}
/**
diff --git a/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java b/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java
index f199d4baf..5a2adb60a 100644
--- a/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java
+++ b/registry/registry-client/src/test/java/com/salesforce/multicloudj/registry/driver/OciRegistryClientTest.java
@@ -24,6 +24,7 @@
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
+import java.util.function.Consumer;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.HttpStatus;
@@ -40,6 +41,7 @@
import org.mockito.MockedConstruction;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
/**
* Unit tests for OciRegistryClient. Tests authentication header generation for Basic, Bearer, and
@@ -50,7 +52,7 @@ public class OciRegistryClientTest {
private static final String REGISTRY_ENDPOINT = "https://test-registry.example.com";
private static final String REPOSITORY = "test-repo/test-image";
- @Mock private AuthProvider mockAuthProvider;
+ @Mock private AbstractRegistry mockAuthProvider;
private AutoCloseable mocks;
@@ -1010,8 +1012,7 @@ private MockedStatic mockAuthChallenge() {
* @return mocked CloseableHttpClient
*/
private CloseableHttpClient createMockHttpClientWithExecuteAnswer(
- String blobContent,
- java.util.function.Consumer requestAssertion) {
+ String blobContent, Consumer requestAssertion) {
CloseableHttpClient mockHttpClient = mock(CloseableHttpClient.class);
CloseableHttpResponse mockResponse = mock(CloseableHttpResponse.class);
StatusLine mockStatusLine = mock(StatusLine.class);
@@ -1072,6 +1073,6 @@ private CloseableHttpClient createMockHttpClientForBlob(
@FunctionalInterface
interface ManifestAssertion {
- void assertManifest(com.salesforce.multicloudj.registry.model.Manifest manifest);
+ void assertManifest(Manifest manifest);
}
}