Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ public String getProviderId() {
/** Returns a builder instance for this registry type. */
public abstract Builder<?, ?> builder();

@Override
public abstract String getAuthUsername();

@Override
public abstract String getAuthToken();

/** Returns the OCI client for this registry. */
protected abstract OciRegistryClient getOciClient();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
*/
public class BearerTokenExchange {

private static final String TOKEN_FIELD = "token";
private static final String ACCESS_TOKEN_FIELD = "access_token";

private final CloseableHttpClient httpClient;

/**
Expand Down Expand Up @@ -59,47 +62,46 @@ public String getBearerToken(
throw new InvalidArgumentException("Bearer challenge missing realm");
}

URI tokenUri = buildTokenUri(realm, challenge, repository, actions);
HttpGet request = new HttpGet(tokenUri);
request.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + identityToken);
try {
// Build token request URL using URIBuilder for proper encoding
URI tokenUri = buildTokenUri(realm, challenge, repository, actions);

HttpGet request = new HttpGet(tokenUri);
// Use identity token as Bearer auth for the token endpoint
request.setHeader(HttpHeaders.AUTHORIZATION, "Bearer " + identityToken);

try (CloseableHttpResponse response = httpClient.execute(request)) {
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode != HttpStatus.SC_OK) {
String errorBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
throw new UnAuthorizedException(
"Token exchange failed: HTTP " + statusCode + " - " + errorBody);
}

String responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);

JsonObject json;
try {
json = JsonParser.parseString(responseBody).getAsJsonObject();
} catch (JsonSyntaxException e) {
throw new UnknownException(
"Invalid JSON response from token endpoint: " + responseBody, e);
}

// Token can be in "token" (Docker Hub, AWS ECR) or "access_token" (GCP Artifact Registry)
// field
if (json.has("token") && !json.get("token").isJsonNull()) {
return json.get("token").getAsString();
} else if (json.has("access_token") && !json.get("access_token").isJsonNull()) {
return json.get("access_token").getAsString();
}

throw new UnknownException("Token response missing token field");
}
return executeTokenRequest(request);
} catch (IOException e) {
throw new UnknownException("Token exchange request failed", e);
}
}

private String executeTokenRequest(HttpGet request) throws IOException {
try (CloseableHttpResponse response = httpClient.execute(request)) {
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode != HttpStatus.SC_OK) {
String errorBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
throw new UnAuthorizedException(
"Token exchange failed: HTTP " + statusCode + " - " + errorBody);
}

String responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);

JsonObject json;
try {
json = JsonParser.parseString(responseBody).getAsJsonObject();
} catch (JsonSyntaxException e) {
throw new UnknownException(
"Invalid JSON response from token endpoint: " + responseBody, e);
}

// Token field is "token" (Docker Hub, AWS ECR) or "access_token" (GCP Artifact Registry)
if (json.has(TOKEN_FIELD) && !json.get(TOKEN_FIELD).isJsonNull()) {
return json.get(TOKEN_FIELD).getAsString();
} else if (json.has(ACCESS_TOKEN_FIELD) && !json.get(ACCESS_TOKEN_FIELD).isJsonNull()) {
return json.get(ACCESS_TOKEN_FIELD).getAsString();
}

throw new UnknownException("Token response missing token field");
}
}

private URI buildTokenUri(
String realm, AuthChallenge challenge, String repository, String[] actions) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,34 +77,43 @@ public InputStream extract() {
return t;
});

executor.submit(
() -> {
try (TarArchiveOutputStream tarOut = new TarArchiveOutputStream(pipedOut)) {
tarOut.setLongFileMode(TarArchiveOutputStream.LONGFILE_POSIX);
tarOut.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_POSIX);

// Track files to handle whiteouts and overwrites
Set<String> seenPaths = new HashSet<>();
Set<String> deletedPaths = new HashSet<>();
Set<String> opaqueDirectories = new HashSet<>();

// Process layers in reverse order (top to bottom) - most recent layer first
// This ensures that the top layer's files take precedence
for (int i = layers.size() - 1; i >= 0; i--) {
processLayer(layers.get(i), tarOut, seenPaths, deletedPaths, opaqueDirectories);
}

tarOut.finish();
} catch (Throwable t) {
extractionError.set(t);
} finally {
try {
pipedOut.close();
} catch (IOException ignored) {
// Ignore close errors
try {
executor.submit(
() -> {
try (TarArchiveOutputStream tarOut = new TarArchiveOutputStream(pipedOut)) {
tarOut.setLongFileMode(TarArchiveOutputStream.LONGFILE_POSIX);
tarOut.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_POSIX);

// Track files to handle whiteouts and overwrites
Set<String> seenPaths = new HashSet<>();
Set<String> deletedPaths = new HashSet<>();
Set<String> opaqueDirectories = new HashSet<>();

// Process layers in reverse order (top to bottom) - most recent layer first
// This ensures that the top layer's files take precedence
for (int i = layers.size() - 1; i >= 0; i--) {
processLayer(layers.get(i), tarOut, seenPaths, deletedPaths, opaqueDirectories);
}

tarOut.finish();
} catch (Throwable t) {
extractionError.set(t);
} finally {
try {
pipedOut.close();
} catch (IOException ignored) {
// Ignore close errors
}
}
}
});
});
} catch (Exception e) {
try {
pipedOut.close();
} catch (IOException ignored) {
// Ignore close errors
}
throw new UnknownException("Failed to start layer extraction", e);
}

// Return a wrapper that checks for extraction errors and cleans up executor on close
return new ExtractionInputStream(pipedIn, extractionError, executor);
Expand All @@ -126,11 +135,8 @@ private void processLayer(
while ((entry = tarIn.getNextTarEntry()) != null) {
String name = normalizePath(entry.getName());

if (handleWhiteout(name, deletedPaths, opaqueDirectories)) {
continue;
}

if (shouldSkipEntry(name, seenPaths, deletedPaths, opaqueDirectories)) {
if (handleWhiteout(name, deletedPaths, opaqueDirectories)
|| shouldSkipEntry(name, seenPaths, deletedPaths, opaqueDirectories)) {
continue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public class OciRegistryClient implements AutoCloseable {
private static final String DIGEST_ALGORITHM = "SHA-256";
private static final String DIGEST_PREFIX = "sha256:";
private static final int MAX_MANIFEST_SIZE_BYTES = 100 * 1024 * 1024; // 100 MB
private static final String MANIFEST_ERROR_FORMAT =
"Failed to fetch manifest for %s:%s from %s - HTTP %d: %s";

private final String registryEndpoint;
private final CloseableHttpClient httpClient;
Expand Down Expand Up @@ -170,88 +172,84 @@ private String buildBearerAuthHeader(String repository) {
* @throws UnknownException if the request fails or manifest cannot be parsed
*/
public Manifest fetchManifest(String repository, String reference) {
String url = String.format("%s/v2/%s/manifests/%s", registryEndpoint, repository, reference);
HttpGet request = buildFetchManifestRequest(repository, reference);
try {
return executeFetchManifestRequest(request, repository, reference);
} catch (IOException e) {
throw new UnknownException("Failed to fetch manifest", e);
}
}

private HttpGet buildFetchManifestRequest(String repository, String reference) {
String url = String.format("%s/v2/%s/manifests/%s", registryEndpoint, repository, reference);
HttpGet request = new HttpGet(url);
request.setHeader(HttpHeaders.ACCEPT, MANIFEST_ACCEPT_HEADER);
String authHeader = getHttpAuthHeader(repository);
if (authHeader != null) {
request.setHeader(HttpHeaders.AUTHORIZATION, authHeader);
}
return request;
}

try {
try (CloseableHttpResponse response = httpClient.execute(request)) {
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode != HttpStatus.SC_OK) {
String errorBody =
response.getEntity() != null
? EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8)
: StringUtils.EMPTY;
if (statusCode == HttpStatus.SC_NOT_FOUND) {
throw new ResourceNotFoundException(
String.format(
"Failed to fetch manifest for %s:%s from %s - HTTP %d: %s",
repository, reference, registryEndpoint, statusCode, errorBody));
}
if (statusCode == HttpStatus.SC_UNAUTHORIZED) {
throw new UnAuthorizedException(
String.format(
"Failed to fetch manifest for %s:%s from %s - HTTP %d: %s",
repository, reference, registryEndpoint, statusCode, errorBody));
}
throw new UnknownException(
String.format(
"Failed to fetch manifest for %s:%s from %s - HTTP %d: %s",
repository, reference, registryEndpoint, statusCode, errorBody));
}
private Manifest executeFetchManifestRequest(
HttpGet request, String repository, String reference) throws IOException {
try (CloseableHttpResponse response = httpClient.execute(request)) {
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode != HttpStatus.SC_OK) {
String errorBody =
response.getEntity() != null
? EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8)
: StringUtils.EMPTY;
String message =
String.format(
MANIFEST_ERROR_FORMAT,
repository, reference, registryEndpoint, statusCode, errorBody);
throw mapHttpStatusToException(statusCode, message);
}

if (response.getEntity() == null) {
throw new UnknownException("Failed to fetch manifest: empty response body");
}
if (response.getEntity() == null) {
throw new UnknownException("Failed to fetch manifest: empty response body");
}

// Check manifest size limit (100 MB)
long contentLength = response.getEntity().getContentLength();
if (contentLength > MAX_MANIFEST_SIZE_BYTES) {
throw new UnknownException(
String.format(
"Manifest size (%d bytes) exceeds maximum allowed size (%d bytes)",
contentLength, MAX_MANIFEST_SIZE_BYTES));
}
// Check manifest size limit to prevent resource exhaustion
long contentLength = response.getEntity().getContentLength();
if (contentLength > MAX_MANIFEST_SIZE_BYTES) {
throw new UnknownException(
String.format(
"Manifest size (%d bytes) exceeds maximum allowed size (%d bytes)",
contentLength, MAX_MANIFEST_SIZE_BYTES));
}

String responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
Header header = response.getFirstHeader("Docker-Content-Digest");
String digestHeader = header != null ? header.getValue() : null;
String responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
String digestHeader = computeManifestDigest(response, responseBody);

// If Docker-Content-Digest header is missing (e.g., AWS ECR), calculate it from the
// response body
if (digestHeader == null) {
try {
MessageDigest md = MessageDigest.getInstance(DIGEST_ALGORITHM);
byte[] hash = md.digest(responseBody.getBytes(StandardCharsets.UTF_8));
digestHeader = DIGEST_PREFIX + Hex.encodeHexString(hash);
} catch (NoSuchAlgorithmException e) {
throw new UnknownException(
"Failed to calculate manifest digest: "
+ DIGEST_ALGORITHM
+ " SHA-256 algorithm not available",
e);
}
}
// Validate digest if fetching by digest reference (e.g., repo@sha256:...)
if (reference.startsWith(DIGEST_PREFIX) && !reference.equals(digestHeader)) {
throw new UnknownException(
String.format(
"Manifest digest mismatch: expected %s, got %s for %s:%s",
reference, digestHeader, repository, reference));
}

// Validate digest if fetching by digest
if (reference.startsWith(DIGEST_PREFIX)) {
if (!reference.equals(digestHeader)) {
throw new UnknownException(
String.format(
"Manifest digest mismatch: expected %s, got %s for %s:%s",
reference, digestHeader, repository, reference));
}
}
return parseManifestResponse(responseBody, digestHeader);
}
}

return parseManifestResponse(responseBody, digestHeader);
}
} catch (IOException e) {
throw new UnknownException("Failed to fetch manifest", e);
private String computeManifestDigest(CloseableHttpResponse response, String responseBody) {
Header header = response.getFirstHeader("Docker-Content-Digest");
if (header != null) {
return header.getValue();
}
// Docker-Content-Digest header may be absent (e.g., AWS ECR); calculate from response body
try {
MessageDigest md = MessageDigest.getInstance(DIGEST_ALGORITHM);
byte[] hash = md.digest(responseBody.getBytes(StandardCharsets.UTF_8));
return DIGEST_PREFIX + Hex.encodeHexString(hash);
} catch (NoSuchAlgorithmException e) {
throw new UnknownException(
"Failed to calculate manifest digest: " + DIGEST_ALGORITHM
+ " SHA-256 algorithm not available",
e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,11 @@ void testDownloadBlob_SetsAuthorizationHeader() throws Exception {
HttpGet request = invocation.getArgument(0);
Header authHeader = request.getFirstHeader("Authorization");
assertNotNull(authHeader, "Authorization header should be set");
assertEquals("Basic dXNlcjp0b2tlbg==", authHeader.getValue()); // Base64("user:token")
assertEquals(
"Basic "
+ Base64.getEncoder()
.encodeToString("user:token".getBytes(StandardCharsets.UTF_8)),
authHeader.getValue());
});

try (MockedStatic<AuthChallenge> mockedAuthChallenge = mockAuthChallenge()) {
Expand Down
Loading