Skip to content

Commit 65f76e9

Browse files
committed
registry: add AWS registry implementation
Add AwsRegistry, an AWS-specific implementation of AbstractRegistry that uses ECR token-based authentication. Add AuthStrippingInterceptor to strip Authorization headers for non-registry hosts, preventing credential leakage on redirects. Add unit tests for both classes, and add junit-jupiter-params test dependency to registry-aws pom.xml.
1 parent 2a43843 commit 65f76e9

File tree

5 files changed

+618
-0
lines changed

5 files changed

+618
-0
lines changed

registry/registry-aws/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@
9090
<version>5.12.1</version>
9191
<scope>test</scope>
9292
</dependency>
93+
<dependency>
94+
<groupId>org.junit.jupiter</groupId>
95+
<artifactId>junit-jupiter-params</artifactId>
96+
<version>5.12.1</version>
97+
<scope>test</scope>
98+
</dependency>
9399
<dependency>
94100
<groupId>org.mockito</groupId>
95101
<artifactId>mockito-junit-jupiter</artifactId>
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package com.salesforce.multicloudj.registry.aws;
2+
3+
import org.apache.http.HttpHeaders;
4+
import org.apache.http.HttpHost;
5+
import org.apache.http.HttpRequest;
6+
import org.apache.http.HttpRequestInterceptor;
7+
import org.apache.http.client.protocol.HttpClientContext;
8+
import org.apache.http.protocol.HttpContext;
9+
10+
import java.net.URI;
11+
12+
/**
13+
* HTTP request interceptor that strips Authorization headers when the request target
14+
* is not the registry host.
15+
*
16+
* <p>This is necessary because AWS ECR redirects blob downloads to S3 pre-signed URLs,
17+
* which already contain authentication in query parameters. Sending an Authorization header to S3
18+
* causes a 400 error ("Only one auth mechanism allowed").
19+
*/
20+
public class AuthStrippingInterceptor implements HttpRequestInterceptor {
21+
private final String registryHost;
22+
23+
/**
24+
* @param registryEndpoint the registry base URL
25+
*/
26+
public AuthStrippingInterceptor(String registryEndpoint) {
27+
this.registryHost = extractHost(registryEndpoint);
28+
}
29+
30+
@Override
31+
public void process(HttpRequest request, HttpContext context) {
32+
HttpHost targetHost = (HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST);
33+
if (targetHost == null) {
34+
return;
35+
}
36+
if (!registryHost.equalsIgnoreCase(targetHost.getHostName())) {
37+
request.removeHeaders(HttpHeaders.AUTHORIZATION);
38+
}
39+
}
40+
41+
/**
42+
* Extracts the hostname from a URL, stripping scheme, port, and path.
43+
*/
44+
private static String extractHost(String url) {
45+
return URI.create(url).getHost();
46+
}
47+
}
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
package com.salesforce.multicloudj.registry.aws;
2+
3+
import com.google.auto.service.AutoService;
4+
import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping;
5+
import com.salesforce.multicloudj.common.aws.CredentialsProvider;
6+
import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
7+
import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
8+
import com.salesforce.multicloudj.common.exceptions.UnknownException;
9+
import com.salesforce.multicloudj.registry.driver.AbstractRegistry;
10+
import com.salesforce.multicloudj.registry.driver.OciRegistryClient;
11+
import org.apache.commons.lang3.StringUtils;
12+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
13+
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
14+
import software.amazon.awssdk.awscore.exception.AwsServiceException;
15+
import software.amazon.awssdk.regions.Region;
16+
import software.amazon.awssdk.services.ecr.EcrClient;
17+
import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest;
18+
import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse;
19+
20+
import java.nio.charset.StandardCharsets;
21+
import java.util.Base64;
22+
23+
/**
24+
* AWS Elastic Container Registry (ECR) implementation.
25+
*
26+
* <p>Authentication uses ECR's GetAuthorizationToken API which returns a Base64-encoded
27+
* {@code AWS:<password>} pair. The token is valid for 12 hours and is cached in memory.
28+
* It is proactively refreshed at the halfway point of its validity window (6 hours).
29+
* If a refresh fails and a cached tokenis available, the cached token is reused as a fallback
30+
* rather than failing the request.
31+
*/
32+
@AutoService(AbstractRegistry.class)
33+
public class AwsRegistry extends AbstractRegistry {
34+
35+
public static final String PROVIDER_ID = "aws";
36+
private static final String AWS_AUTH_USERNAME = "AWS";
37+
38+
/** Lock for thread-safe lazy initialization of ECR client. */
39+
private final Object ecrClientLock = new Object();
40+
41+
/** Lock for thread-safe token refresh. */
42+
private final Object tokenLock = new Object();
43+
44+
private final String region;
45+
private final OciRegistryClient ociClient;
46+
47+
/** Lazily initialized ECR client with double-checked locking. */
48+
private volatile EcrClient ecrClient;
49+
50+
private volatile String cachedAuthToken;
51+
private volatile long tokenRequestedAt;
52+
private volatile long tokenExpirationTime;
53+
54+
public AwsRegistry() {
55+
this(new Builder());
56+
}
57+
58+
public AwsRegistry(Builder builder) {
59+
this(builder, null);
60+
}
61+
62+
/**
63+
* Creates AwsRegistry with specified EcrClient.
64+
*
65+
* @param builder the builder with configuration
66+
* @param ecrClient the ECR client to use (null to create default)
67+
*/
68+
public AwsRegistry(Builder builder, EcrClient ecrClient) {
69+
super(builder);
70+
this.region = builder.getRegion();
71+
this.ecrClient = ecrClient;
72+
this.ociClient = registryEndpoint != null ? new OciRegistryClient(registryEndpoint, this) : null;
73+
}
74+
75+
@Override
76+
public Builder builder() {
77+
return new Builder();
78+
}
79+
80+
@Override
81+
protected OciRegistryClient getOciClient() {
82+
return ociClient;
83+
}
84+
85+
@Override
86+
public String getAuthUsername() {
87+
return AWS_AUTH_USERNAME;
88+
}
89+
90+
@Override
91+
public String getAuthToken() {
92+
if (cachedAuthToken == null || isPastRefreshPoint()) {
93+
synchronized (tokenLock) {
94+
if (cachedAuthToken == null || isPastRefreshPoint()) {
95+
refreshAuthToken();
96+
}
97+
}
98+
}
99+
return cachedAuthToken;
100+
}
101+
102+
/**
103+
* Returns true if the current time is past the halfway point of the token's validity window,
104+
* at which point a proactive refresh is triggered i.e. treats tokens as invalid after 50% of their lifetime.
105+
*/
106+
private boolean isPastRefreshPoint() {
107+
long halfwayPoint = tokenRequestedAt + (tokenExpirationTime - tokenRequestedAt) / 2;
108+
return System.currentTimeMillis() >= halfwayPoint;
109+
}
110+
111+
/**
112+
* Returns the ECR client, initializing lazily with double-checked locking.
113+
*/
114+
private EcrClient getOrCreateEcrClient() {
115+
if (ecrClient == null) {
116+
synchronized (ecrClientLock) {
117+
if (ecrClient == null) {
118+
ecrClient = createEcrClient();
119+
}
120+
}
121+
}
122+
return ecrClient;
123+
}
124+
125+
private EcrClient createEcrClient() {
126+
Region awsRegion = Region.of(region);
127+
AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
128+
if (credentialsOverrider != null) {
129+
AwsCredentialsProvider overrideProvider = CredentialsProvider.getCredentialsProvider(
130+
credentialsOverrider, awsRegion);
131+
if (overrideProvider != null) {
132+
credentialsProvider = overrideProvider;
133+
}
134+
}
135+
return EcrClient.builder()
136+
.region(awsRegion)
137+
.credentialsProvider(credentialsProvider)
138+
.build();
139+
}
140+
141+
/**
142+
* Fetches a fresh ECR authorization token and updates the cache.
143+
* On {@link AwsServiceException}, falls back to the existing cached token if one is available.
144+
*/
145+
private void refreshAuthToken() {
146+
try {
147+
GetAuthorizationTokenResponse response = getOrCreateEcrClient().getAuthorizationToken(
148+
GetAuthorizationTokenRequest.builder().build());
149+
150+
if (response.authorizationData().isEmpty()) {
151+
throw new UnknownException("ECR returned empty authorization data");
152+
}
153+
154+
// ECR token is Base64-encoded "AWS:<password>"; extract the password portion
155+
String encodedToken = response.authorizationData().get(0).authorizationToken();
156+
String decodedToken = new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8);
157+
String[] parts = decodedToken.split(":", 2);
158+
if (parts.length != 2) {
159+
throw new UnknownException("Invalid ECR authorization token format");
160+
}
161+
162+
cachedAuthToken = parts[1];
163+
tokenRequestedAt = System.currentTimeMillis();
164+
tokenExpirationTime = response.authorizationData().get(0).expiresAt().toEpochMilli();
165+
} catch (AwsServiceException e) {
166+
if (cachedAuthToken != null) {
167+
return;
168+
}
169+
throw new UnknownException("Failed to get ECR authorization token", e);
170+
}
171+
}
172+
173+
@Override
174+
public Class<? extends SubstrateSdkException> getException(Throwable t) {
175+
if (t instanceof SubstrateSdkException) {
176+
return (Class<? extends SubstrateSdkException>) t.getClass();
177+
} else if (t instanceof AwsServiceException) {
178+
AwsServiceException awsException = (AwsServiceException) t;
179+
String errorCode = awsException.awsErrorDetails().errorCode();
180+
Class<? extends SubstrateSdkException> mappedException = CommonErrorCodeMapping.get().get(errorCode);
181+
return mappedException != null ? mappedException : UnknownException.class;
182+
} else if (t instanceof IllegalArgumentException) {
183+
return InvalidArgumentException.class;
184+
}
185+
return UnknownException.class;
186+
}
187+
188+
@Override
189+
public void close() throws Exception {
190+
if (ociClient != null) {
191+
ociClient.close();
192+
}
193+
if (ecrClient != null) {
194+
ecrClient.close();
195+
}
196+
}
197+
198+
public static final class Builder extends AbstractRegistry.Builder<AwsRegistry, Builder> {
199+
200+
public Builder() {
201+
providerId(PROVIDER_ID);
202+
}
203+
204+
private String region;
205+
206+
public Builder withRegion(String region) {
207+
this.region = region;
208+
return this;
209+
}
210+
211+
public String getRegion() {
212+
return region;
213+
}
214+
215+
@Override
216+
public Builder self() {
217+
return this;
218+
}
219+
220+
@Override
221+
public AwsRegistry build() {
222+
if (StringUtils.isBlank(registryEndpoint)) {
223+
throw new InvalidArgumentException("Registry endpoint is required for AWS ECR");
224+
}
225+
if (StringUtils.isBlank(region)) {
226+
throw new InvalidArgumentException("AWS region is required");
227+
}
228+
return new AwsRegistry(this);
229+
}
230+
}
231+
}

0 commit comments

Comments
 (0)