Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions registry/registry-aws/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@
<version>5.12.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>5.12.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-junit-jupiter</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.salesforce.multicloudj.registry.aws;

import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
import org.apache.http.HttpRequest;
import org.apache.http.HttpRequestInterceptor;
import org.apache.http.client.protocol.HttpClientContext;
import org.apache.http.protocol.HttpContext;

import java.net.URI;

/**
* HTTP request interceptor that strips Authorization headers when the request target
* is not the registry host.
*
* <p>This is necessary because AWS ECR redirects blob downloads to S3 pre-signed URLs,
* which already contain authentication in query parameters. Sending an Authorization header to S3
* causes a 400 error ("Only one auth mechanism allowed").
*/
public class AuthStrippingInterceptor implements HttpRequestInterceptor {
private final String registryHost;

/**
* @param registryEndpoint the registry base URL
*/
public AuthStrippingInterceptor(String registryEndpoint) {
this.registryHost = extractHost(registryEndpoint);
}

@Override
public void process(HttpRequest request, HttpContext context) {
HttpHost targetHost = (HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST);
if (targetHost == null) {
return;
}
if (!registryHost.equalsIgnoreCase(targetHost.getHostName())) {
request.removeHeaders(HttpHeaders.AUTHORIZATION);
}
}

/**
* Extracts the hostname from a URL, stripping scheme, port, and path.
*/
private static String extractHost(String url) {
return URI.create(url).getHost();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
package com.salesforce.multicloudj.registry.aws;

import com.google.auto.service.AutoService;
import com.salesforce.multicloudj.common.aws.AwsConstants;
import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping;
import com.salesforce.multicloudj.common.aws.CredentialsProvider;
import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
import com.salesforce.multicloudj.common.exceptions.UnknownException;
import com.salesforce.multicloudj.registry.driver.AbstractRegistry;
import com.salesforce.multicloudj.registry.driver.OciRegistryClient;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpRequestInterceptor;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.ecr.EcrClient;
import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest;
import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.List;

/**
* AWS Elastic Container Registry (ECR) implementation.
*
* <p>Authentication uses ECR's GetAuthorizationToken API which returns a Base64-encoded
* {@code AWS:<password>} pair. The token is valid for 12 hours and is cached in memory.
* It is proactively refreshed at the halfway point of its validity window (6 hours).
* If a refresh fails and a cached tokenis available, the cached token is reused as a fallback
* rather than failing the request.
*/
@AutoService(AbstractRegistry.class)
public class AwsRegistry extends AbstractRegistry {

private static final String AWS_AUTH_USERNAME = "AWS";

/** Lock for thread-safe lazy initialization of ECR client. */
private final Object ecrClientLock = new Object();

/** Lock for thread-safe token refresh. */
private final Object tokenLock = new Object();

private final OciRegistryClient ociClient;

/** Lazily initialized ECR client with double-checked locking. */
private volatile EcrClient ecrClient;

private volatile String cachedAuthToken;
private volatile long tokenRequestedAt;
private volatile long tokenExpirationTime;

public AwsRegistry() {
this(new Builder());
}

public AwsRegistry(Builder builder) {
this(builder, null);
}

/**
* Creates AwsRegistry with specified EcrClient.
*
* @param builder the builder with configuration
* @param ecrClient the ECR client to use (null to create default)
*/
public AwsRegistry(Builder builder, EcrClient ecrClient) {
super(builder);
this.ecrClient = ecrClient;
this.ociClient = registryEndpoint != null ? new OciRegistryClient(registryEndpoint, this) : null;
}

@Override
public Builder builder() {
return new Builder();
}

@Override
protected OciRegistryClient getOciClient() {
return ociClient;
}

@Override
public String getAuthUsername() {
return AWS_AUTH_USERNAME;
}

@Override
public String getAuthToken() {
if (cachedAuthToken == null || isPastRefreshPoint()) {
synchronized (tokenLock) {
if (cachedAuthToken == null || isPastRefreshPoint()) {
refreshAuthToken();
}
}
}
return cachedAuthToken;
}

@Override
protected List<HttpRequestInterceptor> getInterceptors() {
return List.of(new AuthStrippingInterceptor(registryEndpoint));
}

/**
* Returns true if the current time is past the halfway point of the token's validity window,
* at which point a proactive refresh is triggered i.e. treats tokens as invalid after 50% of their lifetime.
*/
private boolean isPastRefreshPoint() {
long halfwayPoint = tokenRequestedAt + (tokenExpirationTime - tokenRequestedAt) / 2;
return System.currentTimeMillis() >= halfwayPoint;
}

/**
* Returns the ECR client, initializing lazily with double-checked locking.
*/
private EcrClient getOrCreateEcrClient() {
if (ecrClient == null) {
synchronized (ecrClientLock) {
if (ecrClient == null) {
ecrClient = createEcrClient();
}
}
}
return ecrClient;
}

private EcrClient createEcrClient() {
Region awsRegion = Region.of(region);
AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
if (credentialsOverrider != null) {
AwsCredentialsProvider overrideProvider = CredentialsProvider.getCredentialsProvider(
credentialsOverrider, awsRegion);
if (overrideProvider != null) {
credentialsProvider = overrideProvider;
}
}
return EcrClient.builder()
.region(awsRegion)
.credentialsProvider(credentialsProvider)
.build();
}

/**
* Fetches a fresh ECR authorization token and updates the cache.
* On {@link AwsServiceException}, falls back to the existing cached token if one is available.
*/
private void refreshAuthToken() {
try {
GetAuthorizationTokenResponse response = getOrCreateEcrClient().getAuthorizationToken(
GetAuthorizationTokenRequest.builder().build());

if (response.authorizationData().isEmpty()) {
throw new UnknownException("ECR returned empty authorization data");
}

// ECR token is Base64-encoded "AWS:<password>"; extract the password portion
String encodedToken = response.authorizationData().get(0).authorizationToken();
String decodedToken = new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8);
String[] parts = decodedToken.split(":", 2);
if (parts.length != 2) {
throw new UnknownException("Invalid ECR authorization token format");
}

cachedAuthToken = parts[1];
tokenRequestedAt = System.currentTimeMillis();
tokenExpirationTime = response.authorizationData().get(0).expiresAt().toEpochMilli();
} catch (AwsServiceException e) {
if (cachedAuthToken != null) {
return;
}
throw new UnknownException("Failed to get ECR authorization token", e);
}
}

@Override
public Class<? extends SubstrateSdkException> getException(Throwable t) {
if (t instanceof SubstrateSdkException) {
return (Class<? extends SubstrateSdkException>) t.getClass();
} else if (t instanceof AwsServiceException) {
AwsServiceException awsException = (AwsServiceException) t;
String errorCode = awsException.awsErrorDetails().errorCode();
Class<? extends SubstrateSdkException> mappedException = CommonErrorCodeMapping.get().get(errorCode);
return mappedException != null ? mappedException : UnknownException.class;
} else if (t instanceof IllegalArgumentException) {
return InvalidArgumentException.class;
}
return UnknownException.class;
}

@Override
public void close() throws Exception {
if (ociClient != null) {
ociClient.close();
}
if (ecrClient != null) {
ecrClient.close();
}
}

public static final class Builder extends AbstractRegistry.Builder<AwsRegistry, Builder> {

public Builder() {
providerId(AwsConstants.PROVIDER_ID);
}

@Override
public Builder self() {
return this;
}

@Override
public AwsRegistry build() {
if (StringUtils.isBlank(registryEndpoint)) {
throw new InvalidArgumentException("Registry endpoint is required for AWS ECR");
}
if (StringUtils.isBlank(region)) {
throw new InvalidArgumentException("AWS region is required");
}
return new AwsRegistry(this);
}
}
}
Loading
Loading