diff --git a/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java b/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java index eb68ddb44d60..9f7de9e6dff3 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java +++ b/aws/src/main/java/org/apache/iceberg/aws/AwsClientProperties.java @@ -20,7 +20,6 @@ import java.io.Serializable; import java.util.Map; -import java.util.Optional; import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.aws.s3.VendedCredentialsProvider; import org.apache.iceberg.common.DynClasses; @@ -28,7 +27,6 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.base.Strings; import org.apache.iceberg.rest.RESTUtil; -import org.apache.iceberg.rest.auth.OAuth2Properties; import org.apache.iceberg.util.PropertyUtil; import org.apache.iceberg.util.SerializableMap; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; @@ -198,12 +196,9 @@ public void applyClientCredentialConfigurati public AwsCredentialsProvider credentialsProvider( String accessKeyId, String secretAccessKey, String sessionToken) { if (refreshCredentialsEnabled && !Strings.isNullOrEmpty(refreshCredentialsEndpoint)) { + clientCredentialsProviderProperties.putAll(allProperties); clientCredentialsProviderProperties.put( VendedCredentialsProvider.URI, refreshCredentialsEndpoint); - Optional.ofNullable(allProperties.get(OAuth2Properties.TOKEN)) - .ifPresent( - token -> - clientCredentialsProviderProperties.putIfAbsent(OAuth2Properties.TOKEN, token)); return credentialsProvider(VendedCredentialsProvider.class.getName()); } diff --git a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java index da96f4cb8f48..fc42bd789859 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java +++ b/aws/src/main/java/org/apache/iceberg/aws/s3/VendedCredentialsProvider.java @@ -24,6 +24,7 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.base.Strings; import org.apache.iceberg.rest.ErrorHandlers; @@ -47,17 +48,23 @@ public class VendedCredentialsProvider implements AwsCredentialsProvider, SdkAut private volatile HTTPClient client; private final Map properties; private final CachedSupplier credentialCache; + private final String catalogEndpoint; + private final String credentialsEndpoint; private AuthManager authManager; private AuthSession authSession; private VendedCredentialsProvider(Map properties) { Preconditions.checkArgument(null != properties, "Invalid properties: null"); - Preconditions.checkArgument(null != properties.get(URI), "Invalid URI: null"); + Preconditions.checkArgument(null != properties.get(URI), "Invalid credentials endpoint: null"); + Preconditions.checkArgument( + null != properties.get(CatalogProperties.URI), "Invalid catalog endpoint: null"); this.properties = properties; this.credentialCache = CachedSupplier.builder(() -> credentialFromProperties().orElseGet(this::refreshCredential)) .cachedValueName(VendedCredentialsProvider.class.getName()) .build(); + this.catalogEndpoint = properties.get(CatalogProperties.URI); + this.credentialsEndpoint = properties.get(URI); } @Override @@ -82,7 +89,7 @@ private RESTClient httpClient() { synchronized (this) { if (null == client) { authManager = AuthManagers.loadAuthManager("s3-credentials-refresh", properties); - HTTPClient httpClient = HTTPClient.builder(properties).uri(properties.get(URI)).build(); + HTTPClient httpClient = HTTPClient.builder(properties).uri(catalogEndpoint).build(); authSession = authManager.catalogSession(httpClient, properties); client = httpClient.withAuthSession(authSession); } @@ -95,7 +102,7 @@ private RESTClient httpClient() { private LoadCredentialsResponse fetchCredentials() { return httpClient() .get( - properties.get(URI), + credentialsEndpoint, null, LoadCredentialsResponse.class, Map.of(), diff --git a/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java b/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java index 64b91c052631..e59b43263f99 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java +++ b/aws/src/test/java/org/apache/iceberg/aws/AwsClientPropertiesTest.java @@ -122,6 +122,8 @@ public void refreshCredentialsEndpoint() { AwsClientProperties awsClientProperties = new AwsClientProperties( ImmutableMap.of( + CatalogProperties.URI, + "http://localhost:1234/v1", AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT, "http://localhost:1234/v1/credentials")); @@ -150,6 +152,8 @@ public void refreshCredentialsEndpointWithOAuthToken() { ImmutableMap.of( AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT, "http://localhost:1234/v1/credentials", + CatalogProperties.URI, + "http://localhost:1234/v1/catalog", OAuth2Properties.TOKEN, "oauth-token")); @@ -161,36 +165,41 @@ public void refreshCredentialsEndpointWithOAuthToken() { .extracting("properties") .isEqualTo( ImmutableMap.of( + AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT, + "http://localhost:1234/v1/credentials", "credentials.uri", "http://localhost:1234/v1/credentials", + CatalogProperties.URI, + "http://localhost:1234/v1/catalog", OAuth2Properties.TOKEN, "oauth-token")); } @Test public void refreshCredentialsEndpointWithOverridingOAuthToken() { - AwsClientProperties awsClientProperties = - new AwsClientProperties( - ImmutableMap.of( - AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT, - "http://localhost:1234/v1/credentials", - OAuth2Properties.TOKEN, - "oauth-token", - "client.credentials-provider.token", - "specific-token")); + Map properties = + ImmutableMap.of( + CatalogProperties.URI, + "http://localhost:1234/v1", + AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT, + "http://localhost:1234/v1/credentials", + OAuth2Properties.TOKEN, + "oauth-token", + "client.credentials-provider.token", + "specific-token"); + AwsClientProperties awsClientProperties = new AwsClientProperties(properties); + + Map expectedProperties = + ImmutableMap.builder() + .putAll(properties) + .put("credentials.uri", "http://localhost:1234/v1/credentials") + .build(); AwsCredentialsProvider provider = awsClientProperties.credentialsProvider("key", "secret", "token"); assertThat(provider).isInstanceOf(VendedCredentialsProvider.class); VendedCredentialsProvider vendedCredentialsProvider = (VendedCredentialsProvider) provider; - assertThat(vendedCredentialsProvider) - .extracting("properties") - .isEqualTo( - ImmutableMap.of( - "credentials.uri", - "http://localhost:1234/v1/credentials", - OAuth2Properties.TOKEN, - "specific-token")); + assertThat(vendedCredentialsProvider).extracting("properties").isEqualTo(expectedProperties); } @Test @@ -213,8 +222,12 @@ public void refreshCredentialsEndpointWithRelativePath() { .extracting("properties") .isEqualTo( ImmutableMap.of( + CatalogProperties.URI, + "http://localhost:1234/v1", "credentials.uri", "http://localhost:1234/v1/relative/credentials/endpoint", + AwsClientProperties.REFRESH_CREDENTIALS_ENDPOINT, + "/relative/credentials/endpoint", OAuth2Properties.TOKEN, "oauth-token")); } diff --git a/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java b/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java index 51aca8894300..8e2f99e0ccbe 100644 --- a/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java +++ b/aws/src/test/java/org/apache/iceberg/aws/s3/TestVendedCredentialsProvider.java @@ -26,6 +26,7 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; +import org.apache.iceberg.CatalogProperties; import org.apache.iceberg.exceptions.RESTException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.rest.HttpMethod; @@ -48,7 +49,12 @@ public class TestVendedCredentialsProvider { private static final int PORT = 3232; - private static final String URI = String.format("http://127.0.0.1:%d/v1/credentials", PORT); + private static final String CREDENTIALS_URI = + String.format("http://127.0.0.1:%d/v1/credentials", PORT); + private static final String CATALOG_URI = String.format("http://127.0.0.1:%d/v1", PORT); + private static final ImmutableMap PROPERTIES = + ImmutableMap.of( + VendedCredentialsProvider.URI, CREDENTIALS_URI, CatalogProperties.URI, CATALOG_URI); private static ClientAndServer mockServer; @BeforeAll @@ -73,14 +79,24 @@ public void invalidOrMissingUri() { .hasMessage("Invalid properties: null"); assertThatThrownBy(() -> VendedCredentialsProvider.create(ImmutableMap.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Invalid URI: null"); + .hasMessage("Invalid credentials endpoint: null"); + assertThatThrownBy( + () -> + VendedCredentialsProvider.create( + ImmutableMap.of("credentials.uri", "/credentials/uri"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid catalog endpoint: null"); try (VendedCredentialsProvider provider = VendedCredentialsProvider.create( - ImmutableMap.of(VendedCredentialsProvider.URI, "invalid uri"))) { + ImmutableMap.of( + VendedCredentialsProvider.URI, + "/credentials/uri", + CatalogProperties.URI, + "invalid catalog uri"))) { assertThatThrownBy(provider::resolveCredentials) .isInstanceOf(RESTException.class) - .hasMessageStartingWith("Failed to create request URI from base invalid uri"); + .hasMessageStartingWith("Failed to create request URI from base invalid catalog uri"); } } @@ -95,8 +111,7 @@ public void noS3Credentials() { .withStatusCode(200); mockServer.when(mockRequest).respond(mockResponse); - try (VendedCredentialsProvider provider = - VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) { + try (VendedCredentialsProvider provider = VendedCredentialsProvider.create(PROPERTIES)) { assertThatThrownBy(provider::resolveCredentials) .isInstanceOf(IllegalStateException.class) .hasMessage("Invalid S3 Credentials: empty"); @@ -124,8 +139,7 @@ public void accessKeyIdAndSecretAccessKeyWithoutToken() { response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); mockServer.when(mockRequest).respond(mockResponse); - try (VendedCredentialsProvider provider = - VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) { + try (VendedCredentialsProvider provider = VendedCredentialsProvider.create(PROPERTIES)) { assertThatThrownBy(provider::resolveCredentials) .isInstanceOf(IllegalStateException.class) .hasMessage("Invalid S3 Credentials: s3.session-token not set"); @@ -155,8 +169,7 @@ public void expirationNotSet() { response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); mockServer.when(mockRequest).respond(mockResponse); - try (VendedCredentialsProvider provider = - VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) { + try (VendedCredentialsProvider provider = VendedCredentialsProvider.create(PROPERTIES)) { assertThatThrownBy(provider::resolveCredentials) .isInstanceOf(IllegalStateException.class) .hasMessage("Invalid S3 Credentials: s3.session-token-expires-at-ms not set"); @@ -187,8 +200,7 @@ public void nonExpiredToken() { response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); mockServer.when(mockRequest).respond(mockResponse); - try (VendedCredentialsProvider provider = - VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) { + try (VendedCredentialsProvider provider = VendedCredentialsProvider.create(PROPERTIES)) { AwsCredentials awsCredentials = provider.resolveCredentials(); verifyCredentials(awsCredentials, credential); @@ -226,8 +238,7 @@ public void expiredToken() { response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); mockServer.when(mockRequest).respond(mockResponse); - try (VendedCredentialsProvider provider = - VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) { + try (VendedCredentialsProvider provider = VendedCredentialsProvider.create(PROPERTIES)) { AwsCredentials awsCredentials = provider.resolveCredentials(); verifyCredentials(awsCredentials, credential); @@ -294,8 +305,7 @@ public void multipleS3Credentials() { response(LoadCredentialsResponseParser.toJson(response)).withStatusCode(200); mockServer.when(mockRequest).respond(mockResponse); - try (VendedCredentialsProvider provider = - VendedCredentialsProvider.create(ImmutableMap.of(VendedCredentialsProvider.URI, URI))) { + try (VendedCredentialsProvider provider = VendedCredentialsProvider.create(PROPERTIES)) { assertThatThrownBy(provider::resolveCredentials) .isInstanceOf(IllegalStateException.class) .hasMessage("Invalid S3 Credentials: only one S3 credential should exist"); @@ -345,8 +355,10 @@ public void nonExpiredTokenInProperties() { try (VendedCredentialsProvider provider = VendedCredentialsProvider.create( ImmutableMap.of( + CatalogProperties.URI, + CATALOG_URI, VendedCredentialsProvider.URI, - URI, + CREDENTIALS_URI, S3FileIOProperties.ACCESS_KEY_ID, "randomAccessKeyFromProperties", S3FileIOProperties.SECRET_ACCESS_KEY, @@ -397,8 +409,10 @@ public void expiredTokenInProperties() { try (VendedCredentialsProvider provider = VendedCredentialsProvider.create( ImmutableMap.of( + CatalogProperties.URI, + CATALOG_URI, VendedCredentialsProvider.URI, - URI, + CREDENTIALS_URI, S3FileIOProperties.ACCESS_KEY_ID, "randomAccessKeyFromProperties", S3FileIOProperties.SECRET_ACCESS_KEY, @@ -450,8 +464,10 @@ public void invalidTokenInProperties() { try (VendedCredentialsProvider provider = VendedCredentialsProvider.create( ImmutableMap.of( + CatalogProperties.URI, + CATALOG_URI, VendedCredentialsProvider.URI, - URI, + CREDENTIALS_URI, S3FileIOProperties.ACCESS_KEY_ID, "randomAccessKeyFromProperties", S3FileIOProperties.SECRET_ACCESS_KEY,