Skip to content

Commit 077997e

Browse files
authored
registry: add AWS registry implementation (#324)
1 parent be14812 commit 077997e

File tree

9 files changed

+737
-61
lines changed

9 files changed

+737
-61
lines changed

registry/registry-aws/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@
9090
<version>5.12.1</version>
9191
<scope>test</scope>
9292
</dependency>
93+
<dependency>
94+
<groupId>org.junit.jupiter</groupId>
95+
<artifactId>junit-jupiter-params</artifactId>
96+
<version>5.12.1</version>
97+
<scope>test</scope>
98+
</dependency>
9399
<dependency>
94100
<groupId>org.mockito</groupId>
95101
<artifactId>mockito-junit-jupiter</artifactId>
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package com.salesforce.multicloudj.registry.aws;
2+
3+
import java.net.URI;
4+
import org.apache.http.HttpHeaders;
5+
import org.apache.http.HttpHost;
6+
import org.apache.http.HttpRequest;
7+
import org.apache.http.HttpRequestInterceptor;
8+
import org.apache.http.client.protocol.HttpClientContext;
9+
import org.apache.http.protocol.HttpContext;
10+
11+
/**
12+
* HTTP request interceptor that strips Authorization headers when the request target is not the
13+
* registry host.
14+
*
15+
* <p>This is necessary because AWS ECR redirects blob downloads to S3 pre-signed URLs, which
16+
* already contain authentication in query parameters. Sending an Authorization header to S3 causes
17+
* a 400 error ("Only one auth mechanism allowed").
18+
*/
19+
public class AuthStrippingInterceptor implements HttpRequestInterceptor {
20+
private final String registryHost;
21+
22+
/**
23+
* @param registryEndpoint the registry base URL
24+
*/
25+
public AuthStrippingInterceptor(String registryEndpoint) {
26+
this.registryHost = extractHost(registryEndpoint);
27+
}
28+
29+
@Override
30+
public void process(HttpRequest request, HttpContext context) {
31+
HttpHost targetHost = (HttpHost) context.getAttribute(HttpClientContext.HTTP_TARGET_HOST);
32+
if (targetHost == null) {
33+
return;
34+
}
35+
if (!registryHost.equalsIgnoreCase(targetHost.getHostName())) {
36+
request.removeHeaders(HttpHeaders.AUTHORIZATION);
37+
}
38+
}
39+
40+
/** Extracts the hostname from a URL, stripping scheme, port, and path. */
41+
private static String extractHost(String url) {
42+
return URI.create(url).getHost();
43+
}
44+
}
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
package com.salesforce.multicloudj.registry.aws;
2+
3+
import com.google.auto.service.AutoService;
4+
import com.salesforce.multicloudj.common.aws.AwsConstants;
5+
import com.salesforce.multicloudj.common.aws.CommonErrorCodeMapping;
6+
import com.salesforce.multicloudj.common.aws.CredentialsProvider;
7+
import com.salesforce.multicloudj.common.exceptions.InvalidArgumentException;
8+
import com.salesforce.multicloudj.common.exceptions.SubstrateSdkException;
9+
import com.salesforce.multicloudj.common.exceptions.UnknownException;
10+
import com.salesforce.multicloudj.registry.driver.AbstractRegistry;
11+
import com.salesforce.multicloudj.registry.driver.OciRegistryClient;
12+
import java.nio.charset.StandardCharsets;
13+
import java.util.Base64;
14+
import java.util.List;
15+
import org.apache.commons.lang3.StringUtils;
16+
import org.apache.http.HttpRequestInterceptor;
17+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
18+
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
19+
import software.amazon.awssdk.awscore.exception.AwsServiceException;
20+
import software.amazon.awssdk.regions.Region;
21+
import software.amazon.awssdk.services.ecr.EcrClient;
22+
import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenRequest;
23+
import software.amazon.awssdk.services.ecr.model.GetAuthorizationTokenResponse;
24+
25+
/**
26+
* AWS Elastic Container Registry (ECR) implementation.
27+
*
28+
* <p>Authentication uses ECR's GetAuthorizationToken API which returns a Base64-encoded {@code
29+
* AWS:<password>} pair. The token is valid for 12 hours and is cached in memory. It is proactively
30+
* refreshed at the halfway point of its validity window (6 hours). If a refresh fails and a cached
31+
* tokenis available, the cached token is reused as a fallback rather than failing the request.
32+
*/
33+
@AutoService(AbstractRegistry.class)
34+
public class AwsRegistry extends AbstractRegistry {
35+
36+
private static final String AWS_AUTH_USERNAME = "AWS";
37+
38+
/** Lock for thread-safe lazy initialization of ECR client. */
39+
private final Object ecrClientLock = new Object();
40+
41+
/** Lock for thread-safe token refresh. */
42+
private final Object tokenLock = new Object();
43+
44+
private final OciRegistryClient ociClient;
45+
46+
/** Lazily initialized ECR client with double-checked locking. */
47+
private volatile EcrClient ecrClient;
48+
49+
private volatile String cachedAuthToken;
50+
private volatile long tokenRequestedAt;
51+
private volatile long tokenExpirationTime;
52+
53+
public AwsRegistry() {
54+
this(new Builder());
55+
}
56+
57+
public AwsRegistry(Builder builder) {
58+
this(builder, null);
59+
}
60+
61+
/**
62+
* Creates AwsRegistry with specified EcrClient.
63+
*
64+
* @param builder the builder with configuration
65+
* @param ecrClient the ECR client to use (null to create default)
66+
*/
67+
public AwsRegistry(Builder builder, EcrClient ecrClient) {
68+
super(builder);
69+
this.ecrClient = ecrClient;
70+
this.ociClient =
71+
registryEndpoint != null ? new OciRegistryClient(registryEndpoint, this) : null;
72+
}
73+
74+
@Override
75+
public Builder builder() {
76+
return new Builder();
77+
}
78+
79+
@Override
80+
protected OciRegistryClient getOciClient() {
81+
return ociClient;
82+
}
83+
84+
@Override
85+
public String getAuthUsername() {
86+
return AWS_AUTH_USERNAME;
87+
}
88+
89+
@Override
90+
public String getAuthToken() {
91+
if (cachedAuthToken == null || isPastRefreshPoint()) {
92+
synchronized (tokenLock) {
93+
if (cachedAuthToken == null || isPastRefreshPoint()) {
94+
refreshAuthToken();
95+
}
96+
}
97+
}
98+
return cachedAuthToken;
99+
}
100+
101+
@Override
102+
protected List<HttpRequestInterceptor> getInterceptors() {
103+
return List.of(new AuthStrippingInterceptor(registryEndpoint));
104+
}
105+
106+
/**
107+
* Returns true if the current time is past the halfway point of the token's validity window, at
108+
* which point a proactive refresh is triggered i.e. treats tokens as invalid after 50% of their
109+
* lifetime.
110+
*/
111+
private boolean isPastRefreshPoint() {
112+
long halfwayPoint = tokenRequestedAt + (tokenExpirationTime - tokenRequestedAt) / 2;
113+
return System.currentTimeMillis() >= halfwayPoint;
114+
}
115+
116+
/** Returns the ECR client, initializing lazily with double-checked locking. */
117+
private EcrClient getOrCreateEcrClient() {
118+
if (ecrClient == null) {
119+
synchronized (ecrClientLock) {
120+
if (ecrClient == null) {
121+
ecrClient = createEcrClient();
122+
}
123+
}
124+
}
125+
return ecrClient;
126+
}
127+
128+
private EcrClient createEcrClient() {
129+
Region awsRegion = Region.of(region);
130+
AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.create();
131+
if (credentialsOverrider != null) {
132+
AwsCredentialsProvider overrideProvider =
133+
CredentialsProvider.getCredentialsProvider(credentialsOverrider, awsRegion);
134+
if (overrideProvider != null) {
135+
credentialsProvider = overrideProvider;
136+
}
137+
}
138+
return EcrClient.builder().region(awsRegion).credentialsProvider(credentialsProvider).build();
139+
}
140+
141+
/**
142+
* Fetches a fresh ECR authorization token and updates the cache. On {@link AwsServiceException},
143+
* falls back to the existing cached token if one is available.
144+
*/
145+
private void refreshAuthToken() {
146+
try {
147+
GetAuthorizationTokenResponse response =
148+
getOrCreateEcrClient()
149+
.getAuthorizationToken(GetAuthorizationTokenRequest.builder().build());
150+
151+
if (response.authorizationData().isEmpty()) {
152+
throw new UnknownException("ECR returned empty authorization data");
153+
}
154+
155+
// ECR token is Base64-encoded "AWS:<password>"; extract the password portion
156+
String encodedToken = response.authorizationData().get(0).authorizationToken();
157+
String decodedToken =
158+
new String(Base64.getDecoder().decode(encodedToken), StandardCharsets.UTF_8);
159+
String[] parts = decodedToken.split(":", 2);
160+
if (parts.length != 2) {
161+
throw new UnknownException("Invalid ECR authorization token format");
162+
}
163+
164+
cachedAuthToken = parts[1];
165+
tokenRequestedAt = System.currentTimeMillis();
166+
tokenExpirationTime = response.authorizationData().get(0).expiresAt().toEpochMilli();
167+
} catch (AwsServiceException e) {
168+
if (cachedAuthToken != null) {
169+
return;
170+
}
171+
throw new UnknownException("Failed to get ECR authorization token", e);
172+
}
173+
}
174+
175+
@Override
176+
public Class<? extends SubstrateSdkException> getException(Throwable t) {
177+
if (t instanceof SubstrateSdkException) {
178+
return (Class<? extends SubstrateSdkException>) t.getClass();
179+
} else if (t instanceof AwsServiceException) {
180+
AwsServiceException awsException = (AwsServiceException) t;
181+
String errorCode = awsException.awsErrorDetails().errorCode();
182+
Class<? extends SubstrateSdkException> mappedException =
183+
CommonErrorCodeMapping.get().get(errorCode);
184+
return mappedException != null ? mappedException : UnknownException.class;
185+
} else if (t instanceof IllegalArgumentException) {
186+
return InvalidArgumentException.class;
187+
}
188+
return UnknownException.class;
189+
}
190+
191+
@Override
192+
public void close() throws Exception {
193+
if (ociClient != null) {
194+
ociClient.close();
195+
}
196+
if (ecrClient != null) {
197+
ecrClient.close();
198+
}
199+
}
200+
201+
public static final class Builder extends AbstractRegistry.Builder<AwsRegistry, Builder> {
202+
203+
public Builder() {
204+
providerId(AwsConstants.PROVIDER_ID);
205+
}
206+
207+
@Override
208+
public Builder self() {
209+
return this;
210+
}
211+
212+
@Override
213+
public AwsRegistry build() {
214+
if (StringUtils.isBlank(registryEndpoint)) {
215+
throw new InvalidArgumentException("Registry endpoint is required for AWS ECR");
216+
}
217+
if (StringUtils.isBlank(region)) {
218+
throw new InvalidArgumentException("AWS region is required");
219+
}
220+
return new AwsRegistry(this);
221+
}
222+
}
223+
}

0 commit comments

Comments
 (0)