diff --git a/parent-pom.xml b/parent-pom.xml index 3b1b41b014..257b2331a9 100644 --- a/parent-pom.xml +++ b/parent-pom.xml @@ -611,6 +611,10 @@ com.amazonaws aws-java-sdk-s3 + + com.amazonaws + aws-java-sdk-sts + com.fasterxml.jackson.core jackson-annotations diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java index 72ed3157f2..21c24eadc3 100644 --- a/src/main/java/net/snowflake/client/core/SFLoginInput.java +++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java @@ -5,7 +5,9 @@ import java.security.PrivateKey; import java.time.Duration; import java.util.Map; +import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation; import net.snowflake.client.jdbc.ErrorCode; +import org.apache.http.client.methods.HttpRequestBase; /** A class for holding all information required for login */ public class SFLoginInput { @@ -57,6 +59,10 @@ public class SFLoginInput { private boolean enableClientStoreTemporaryCredential; private boolean enableClientRequestMfaToken; + // Workload Identity Federation + private String workloadIdentityProvider; + private WorkloadIdentityAttestation workloadIdentityAttestation; + // OAuth private int redirectUriPort = -1; private String clientId; @@ -342,6 +348,15 @@ SFLoginInput setOauthRefreshToken(String oauthRefreshToken) { return this; } + String getWorkloadIdentityProvider() { + return workloadIdentityProvider; + } + + SFLoginInput setWorkloadIdentityProvider(String workloadIdentityProvider) { + this.workloadIdentityProvider = workloadIdentityProvider; + return this; + } + @SnowflakeJdbcInternalApi public String getDPoPPublicKey() { return dpopPublicKey; @@ -505,14 +520,11 @@ Map getAdditionalHttpHeadersForSnowsight() { * Set additional http headers to apply to the outgoing request. The additional headers cannot be * used to replace or overwrite a header in use by the driver. These will be applied to the * outgoing request. Primarily used by Snowsight, as described in {@link - * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase, - * Map)} + * HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)} * * @param additionalHttpHeaders The new headers to add * @return The input object, for chaining - * @see - * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase, - * Map) + * @see HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map) */ public SFLoginInput setAdditionalHttpHeadersForSnowsight( Map additionalHttpHeaders) { @@ -589,4 +601,13 @@ SFLoginInput setOriginalAuthenticator(String originalAuthenticator) { this.originalAuthenticator = originalAuthenticator; return this; } + + public void setWorkloadIdentityAttestation( + WorkloadIdentityAttestation workloadIdentityAttestation) { + this.workloadIdentityAttestation = workloadIdentityAttestation; + } + + public WorkloadIdentityAttestation getWorkloadIdentityAttestation() { + return workloadIdentityAttestation; + } } diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java index eafe2263d8..5a4fb1573a 100644 --- a/src/main/java/net/snowflake/client/core/SFSession.java +++ b/src/main/java/net/snowflake/client/core/SFSession.java @@ -721,6 +721,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException { .setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY)) .setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE)) .setOauthLoginInput(oauthLoginInput) + .setWorkloadIdentityProvider( + (String) connectionPropertiesMap.get(SFSessionProperty.WORKLOAD_IDENTITY_PROVIDER)) .setPrivateKeyBase64( (String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64)) .setPrivateKeyPwd( diff --git a/src/main/java/net/snowflake/client/core/SFSessionProperty.java b/src/main/java/net/snowflake/client/core/SFSessionProperty.java index bc20892d13..eba4f4ae01 100644 --- a/src/main/java/net/snowflake/client/core/SFSessionProperty.java +++ b/src/main/java/net/snowflake/client/core/SFSessionProperty.java @@ -26,6 +26,7 @@ public enum SFSessionProperty { OAUTH_SCOPE("oauthScope", false, String.class), OAUTH_AUTHORIZATION_URL("oauthAuthorizationUrl", false, String.class), OAUTH_TOKEN_REQUEST_URL("oauthTokenRequestUrl", false, String.class), + WORKLOAD_IDENTITY_PROVIDER("workloadIdentityProvider", false, String.class), WAREHOUSE("warehouse", false, String.class), LOGIN_TIMEOUT("loginTimeout", false, Integer.class), NETWORK_TIMEOUT("networkTimeout", false, Integer.class), diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java index 275ad5a5a6..44e156784b 100644 --- a/src/main/java/net/snowflake/client/core/SessionUtil.java +++ b/src/main/java/net/snowflake/client/core/SessionUtil.java @@ -32,6 +32,13 @@ import net.snowflake.client.core.auth.oauth.OAuthAccessTokenForRefreshTokenProvider; import net.snowflake.client.core.auth.oauth.OAuthAccessTokenProviderFactory; import net.snowflake.client.core.auth.oauth.TokenResponseDTO; +import net.snowflake.client.core.auth.wif.AWSAttestationService; +import net.snowflake.client.core.auth.wif.AwsIdentityAttestationCreator; +import net.snowflake.client.core.auth.wif.AzureIdentityAttestationCreator; +import net.snowflake.client.core.auth.wif.GcpIdentityAttestationCreator; +import net.snowflake.client.core.auth.wif.OidcIdentityAttestationCreator; +import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation; +import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestationProvider; import net.snowflake.client.jdbc.ErrorCode; import net.snowflake.client.jdbc.RetryContext; import net.snowflake.client.jdbc.RetryContextManager; @@ -242,6 +249,10 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) { .getAuthenticator() .equalsIgnoreCase(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN.name())) { return AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN; + } else if (loginInput + .getAuthenticator() + .equalsIgnoreCase(AuthenticatorType.WORKLOAD_IDENTITY.name())) { + return AuthenticatorType.WORKLOAD_IDENTITY; } else if (loginInput .getAuthenticator() .equalsIgnoreCase(AuthenticatorType.SNOWFLAKE_JWT.name())) { @@ -328,6 +339,24 @@ static SFLoginOutput openSession( } } + if (authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) { + WorkloadIdentityAttestationProvider attestationProvider = + new WorkloadIdentityAttestationProvider( + new AwsIdentityAttestationCreator(new AWSAttestationService()), + new GcpIdentityAttestationCreator(), + new AzureIdentityAttestationCreator(), + new OidcIdentityAttestationCreator()); + WorkloadIdentityAttestation attestation = + attestationProvider.getAttestation(loginInput.getWorkloadIdentityProvider()); + if (attestation != null) { + loginInput.setWorkloadIdentityAttestation(attestation); + } else { + throw new SFException( + ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR, + "Unable to obtain workload identity attestation. Make sure that correct workload identity provider has been set and that Snowflake-JDBC driver runs on supported environment."); + } + } + convertSessionParameterStringValueToBooleanIfGiven(loginInput, CLIENT_REQUEST_MFA_TOKEN); readCachedCredentialsIfPossible(loginInput); @@ -353,7 +382,8 @@ static SFLoginOutput openSession( static void checkIfExperimentalAuthnEnabled(AuthenticatorType authenticator) throws SFException { if (authenticator.equals(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN) || authenticator.equals(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS) - || authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) { + || authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE) + || authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) { boolean experimentalAuthenticationMethodsEnabled = Boolean.parseBoolean(systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION")); AssertUtil.assertTrue( @@ -650,6 +680,15 @@ static SFLoginOutput newSession( data.put(ClientAuthnParameter.OAUTH_TYPE.name(), loginInput.getOriginalAuthenticator()); } + if (authenticatorType == AuthenticatorType.WORKLOAD_IDENTITY) { + data.put( + ClientAuthnParameter.TOKEN.name(), + loginInput.getWorkloadIdentityAttestation().getCredential()); + data.put( + ClientAuthnParameter.PROVIDER.name(), + loginInput.getWorkloadIdentityAttestation().getProvider()); + } + // map of client environment parameters, including connection parameters // and environment properties like OS version, etc. Map clientEnv = new HashMap<>(); diff --git a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java index b6a5919395..0251c65776 100644 --- a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java +++ b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java @@ -53,5 +53,10 @@ public enum AuthenticatorType { /* * Authenticator to support PAT created in Snowflake */ - PROGRAMMATIC_ACCESS_TOKEN + PROGRAMMATIC_ACCESS_TOKEN, + + /* + * Authenticator to support existing authentication by existing AWS/GCP/Azure workload identity + */ + WORKLOAD_IDENTITY } diff --git a/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java b/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java index 6b8bc84935..2e4f2fcf83 100644 --- a/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java +++ b/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java @@ -18,5 +18,6 @@ public enum ClientAuthnParameter { SESSION_PARAMETERS, PROOF_KEY, TOKEN, - OAUTH_TYPE + OAUTH_TYPE, + PROVIDER } diff --git a/src/main/java/net/snowflake/client/core/auth/wif/AWSAttestationService.java b/src/main/java/net/snowflake/client/core/auth/wif/AWSAttestationService.java new file mode 100644 index 0000000000..6a08485ca1 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/AWSAttestationService.java @@ -0,0 +1,53 @@ +package net.snowflake.client.core.auth.wif; + +import com.amazonaws.SignableRequest; +import com.amazonaws.auth.AWS4Signer; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.regions.InstanceMetadataRegionProvider; +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; +import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest; +import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult; +import java.util.Optional; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.EnvironmentVariables; +import net.snowflake.client.jdbc.SnowflakeUtil; + +@SnowflakeJdbcInternalApi +public class AWSAttestationService { + + private static final String SECURE_TOKEN_SERVICE_NAME = "sts"; + private static boolean regionInitialized = false; + private static String region; + + private final AWS4Signer aws4Signer; + + public AWSAttestationService() { + aws4Signer = new AWS4Signer(); + aws4Signer.setServiceName(SECURE_TOKEN_SERVICE_NAME); + } + + AWSCredentials getAWSCredentials() { + return DefaultAWSCredentialsProviderChain.getInstance().getCredentials(); + } + + String getAWSRegion() { + if (!regionInitialized) { + String envRegion = SnowflakeUtil.systemGetEnv(EnvironmentVariables.AWS_REGION.getName()); + region = envRegion != null ? envRegion : new InstanceMetadataRegionProvider().getRegion(); + regionInitialized = true; + } + return region; + } + + String getArn() { + GetCallerIdentityResult callerIdentity = + AWSSecurityTokenServiceClientBuilder.defaultClient() + .getCallerIdentity(new GetCallerIdentityRequest()); + return Optional.ofNullable(callerIdentity).map(GetCallerIdentityResult::getArn).orElse(null); + } + + void signRequestWithSigV4(SignableRequest signableRequest, AWSCredentials awsCredentials) { + aws4Signer.sign(signableRequest, awsCredentials); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreator.java new file mode 100644 index 0000000000..b38f070d83 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreator.java @@ -0,0 +1,81 @@ +package net.snowflake.client.core.auth.wif; + +import com.amazonaws.DefaultRequest; +import com.amazonaws.Request; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.http.HttpMethodName; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Collections; +import net.minidev.json.JSONObject; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class AwsIdentityAttestationCreator implements WorkloadIdentityAttestationCreator { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(AwsIdentityAttestationCreator.class); + + private static final String SNOWFLAKE_AUDIENCE_HEADER_NAME = "X-Snowflake-Audience"; + private static final String SNOWFLAKE_AUDIENCE = "snowflakecomputing.com"; + + private final AWSAttestationService attestationService; + + public AwsIdentityAttestationCreator(AWSAttestationService attestationService) { + this.attestationService = attestationService; + } + + @Override + public WorkloadIdentityAttestation createAttestation() { + logger.debug("Creating AWS identity attestation..."); + AWSCredentials awsCredentials = attestationService.getAWSCredentials(); + if (awsCredentials == null) { + logger.debug("No AWS credentials were found."); + return null; + } + String region = attestationService.getAWSRegion(); + if (region == null) { + logger.debug("No AWS region was found."); + return null; + } + String arn = attestationService.getArn(); + if (arn == null) { + logger.debug("No Caller Identity was found."); + return null; + } + + String stsHostname = String.format("sts.%s.amazonaws.com", region); + Request request = createStsRequest(stsHostname); + attestationService.signRequestWithSigV4(request, awsCredentials); + + String credential = createBase64EncodedRequestCredential(request); + return new WorkloadIdentityAttestation( + WorkloadIdentityProviderType.AWS, credential, Collections.singletonMap("arn", arn)); + } + + private Request createStsRequest(String hostname) { + Request request = new DefaultRequest<>("sts"); + request.setHttpMethod(HttpMethodName.POST); + request.setEndpoint( + URI.create( + String.format("https://%s/?Action=GetCallerIdentity&Version=2011-06-15", hostname))); + request.addHeader("Host", hostname); + request.addHeader(SNOWFLAKE_AUDIENCE_HEADER_NAME, SNOWFLAKE_AUDIENCE); + return request; + } + + private String createBase64EncodedRequestCredential(Request request) { + JSONObject assertionJson = new JSONObject(); + JSONObject headers = new JSONObject(); + headers.putAll(request.getHeaders()); + assertionJson.put("url", request.getEndpoint().toString()); + assertionJson.put("method", request.getHttpMethod().toString()); + assertionJson.put("headers", headers); + + String assertionJsonString = assertionJson.toString(); + return Base64.getEncoder().encodeToString(assertionJsonString.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/AzureIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/AzureIdentityAttestationCreator.java new file mode 100644 index 0000000000..2ff1a9c1e2 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/AzureIdentityAttestationCreator.java @@ -0,0 +1,19 @@ +package net.snowflake.client.core.auth.wif; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class AzureIdentityAttestationCreator implements WorkloadIdentityAttestationCreator { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(AzureIdentityAttestationCreator.class); + + @Override + public WorkloadIdentityAttestation createAttestation() throws SFException { + throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "Azure Workload Identity not supported"); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/GcpIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/GcpIdentityAttestationCreator.java new file mode 100644 index 0000000000..c58c6633ca --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/GcpIdentityAttestationCreator.java @@ -0,0 +1,19 @@ +package net.snowflake.client.core.auth.wif; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class GcpIdentityAttestationCreator implements WorkloadIdentityAttestationCreator { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(GcpIdentityAttestationCreator.class); + + @Override + public WorkloadIdentityAttestation createAttestation() throws SFException { + throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "GCP Workload Identity not supported"); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/OidcIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/OidcIdentityAttestationCreator.java new file mode 100644 index 0000000000..edbae18b29 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/OidcIdentityAttestationCreator.java @@ -0,0 +1,19 @@ +package net.snowflake.client.core.auth.wif; + +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class OidcIdentityAttestationCreator implements WorkloadIdentityAttestationCreator { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(OidcIdentityAttestationCreator.class); + + @Override + public WorkloadIdentityAttestation createAttestation() throws SFException { + throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "OIDC Workload Identity not supported"); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestation.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestation.java new file mode 100644 index 0000000000..daf2cc32c6 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestation.java @@ -0,0 +1,33 @@ +package net.snowflake.client.core.auth.wif; + +import java.util.Map; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; + +@SnowflakeJdbcInternalApi +public class WorkloadIdentityAttestation { + + private final WorkloadIdentityProviderType provider; + private final String credential; + private final Map userIdentifiedComponents; + + WorkloadIdentityAttestation( + WorkloadIdentityProviderType provider, + String credential, + Map userIdentifiedComponents) { + this.provider = provider; + this.credential = credential; + this.userIdentifiedComponents = userIdentifiedComponents; + } + + public WorkloadIdentityProviderType getProvider() { + return provider; + } + + public String getCredential() { + return credential; + } + + public Map getUserIdentifiedComponents() { + return userIdentifiedComponents; + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationCreator.java new file mode 100644 index 0000000000..736f6a8400 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationCreator.java @@ -0,0 +1,8 @@ +package net.snowflake.client.core.auth.wif; + +import net.snowflake.client.core.SFException; + +interface WorkloadIdentityAttestationCreator { + + WorkloadIdentityAttestation createAttestation() throws SFException; +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProvider.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProvider.java new file mode 100644 index 0000000000..87b6bf9363 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProvider.java @@ -0,0 +1,78 @@ +package net.snowflake.client.core.auth.wif; + +import com.google.common.base.Strings; +import net.snowflake.client.core.SFException; +import net.snowflake.client.core.SnowflakeJdbcInternalApi; +import net.snowflake.client.jdbc.ErrorCode; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; + +@SnowflakeJdbcInternalApi +public class WorkloadIdentityAttestationProvider { + + private static final SFLogger logger = + SFLoggerFactory.getLogger(WorkloadIdentityAttestationProvider.class); + + private final AwsIdentityAttestationCreator awsAttestationCreator; + private final GcpIdentityAttestationCreator gcpAttestationCreator; + private final AzureIdentityAttestationCreator azureAttestationCreator; + private final OidcIdentityAttestationCreator oidcAttestationCreator; + + public WorkloadIdentityAttestationProvider( + AwsIdentityAttestationCreator awsAttestationCreator, + GcpIdentityAttestationCreator gcpAttestationCreator, + AzureIdentityAttestationCreator azureAttestationCreator, + OidcIdentityAttestationCreator oidcAttestationCreator) { + this.awsAttestationCreator = awsAttestationCreator; + this.gcpAttestationCreator = gcpAttestationCreator; + this.azureAttestationCreator = azureAttestationCreator; + this.oidcAttestationCreator = oidcAttestationCreator; + } + + public WorkloadIdentityAttestation getAttestation(String identityProvider) throws SFException { + if (Strings.isNullOrEmpty(identityProvider)) { + logger.debug("Workload Identity Provider has not been specified. Using autodetect..."); + return createAutodetectAttestation(); + } else { + return getCreator(identityProvider).createAttestation(); + } + } + + WorkloadIdentityAttestationCreator getCreator(String identityProvider) throws SFException { + if (WorkloadIdentityProviderType.AWS.name().equalsIgnoreCase(identityProvider)) { + return awsAttestationCreator; + } else if (WorkloadIdentityProviderType.GCP.name().equalsIgnoreCase(identityProvider)) { + return gcpAttestationCreator; + } else if (WorkloadIdentityProviderType.AZURE.name().equalsIgnoreCase(identityProvider)) { + return azureAttestationCreator; + } else if (WorkloadIdentityProviderType.OIDC.name().equalsIgnoreCase(identityProvider)) { + return oidcAttestationCreator; + } else { + throw new SFException( + ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR, + "Unknown Workload Identity provider specified: " + identityProvider); + } + } + + private WorkloadIdentityAttestation createAutodetectAttestation() throws SFException { + WorkloadIdentityAttestation awsAttestation = awsAttestationCreator.createAttestation(); + if (awsAttestation != null) { + return awsAttestation; + } + WorkloadIdentityAttestation gcpAttestation = gcpAttestationCreator.createAttestation(); + if (gcpAttestation != null) { + return gcpAttestation; + } + WorkloadIdentityAttestation azureAttestation = azureAttestationCreator.createAttestation(); + if (azureAttestation != null) { + return azureAttestation; + } + WorkloadIdentityAttestation oidcAttestation = oidcAttestationCreator.createAttestation(); + if (oidcAttestation != null) { + return oidcAttestation; + } + throw new SFException( + ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR, + "Unable to autodetect Workload Identity. None of supported Workload Identity environments has been identified."); + } +} diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityProviderType.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityProviderType.java new file mode 100644 index 0000000000..74b2c6855d --- /dev/null +++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityProviderType.java @@ -0,0 +1,9 @@ +package net.snowflake.client.core.auth.wif; + +enum WorkloadIdentityProviderType { + AWS, // Provider that builds an encoded pre-signed GetCallerIdentity request using the current + // workload's IAM role. + AZURE, // Provider that requests an OAuth access token for the workload's managed identity. + GCP, // Provider that requests an ID token for the workload's attached service account. + OIDC // Provider that looks for an OIDC ID token. +} diff --git a/src/main/java/net/snowflake/client/jdbc/EnvironmentVariables.java b/src/main/java/net/snowflake/client/jdbc/EnvironmentVariables.java new file mode 100644 index 0000000000..841b5e64c8 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/EnvironmentVariables.java @@ -0,0 +1,18 @@ +package net.snowflake.client.jdbc; + +import net.snowflake.client.core.SnowflakeJdbcInternalApi; + +@SnowflakeJdbcInternalApi +public enum EnvironmentVariables { + AWS_REGION("AWS_REGION"); + + private final String name; + + EnvironmentVariables(String name) { + this.name = name; + } + + public String getName() { + return name; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java index 884df03b70..fb1a373190 100644 --- a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java +++ b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java @@ -82,7 +82,8 @@ public enum ErrorCode { FILE_OPERATION_DOWNLOAD_ERROR(200067, SqlState.INTERNAL_ERROR), OAUTH_AUTHORIZATION_CODE_FLOW_ERROR(200068, SqlState.CONNECTION_EXCEPTION), OAUTH_CLIENT_CREDENTIALS_FLOW_ERROR(200069, SqlState.CONNECTION_EXCEPTION), - OAUTH_REFRESH_TOKEN_FLOW_ERROR(200070, SqlState.CONNECTION_EXCEPTION); + OAUTH_REFRESH_TOKEN_FLOW_ERROR(200070, SqlState.CONNECTION_EXCEPTION), + WORKFLOW_IDENTITY_FLOW_ERROR(200071, SqlState.CONNECTION_EXCEPTION); public static final String errorMessageResource = "net.snowflake.client.jdbc.jdbc_error_messages"; diff --git a/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties b/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties index c3a74d6942..3f5ec0d986 100644 --- a/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties +++ b/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties @@ -89,3 +89,4 @@ Error message={3}, Extended error info={4} 200068=Error during OAuth Authorization Code authentication: {0} 200069=Error during OAuth Client Credentials authentication: {0} 200070=Error during obtaining OAuth access token using refresh token: {0} +200071=Error during Workflow Identity authentication: {0} diff --git a/src/test/java/net/snowflake/client/core/SessionUtilTest.java b/src/test/java/net/snowflake/client/core/SessionUtilTest.java index 788864f545..25a3692352 100644 --- a/src/test/java/net/snowflake/client/core/SessionUtilTest.java +++ b/src/test/java/net/snowflake/client/core/SessionUtilTest.java @@ -1,8 +1,10 @@ package net.snowflake.client.core; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.mockStatic; @@ -22,7 +24,6 @@ import org.apache.http.client.methods.HttpPost; import org.apache.http.client.utils.URIBuilder; import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.mockito.MockedStatic; @@ -255,37 +256,42 @@ public void shouldProperlyCheckIfExperimentalAuthEnabled() { snowflakeUtilMockedStatic .when(() -> SnowflakeUtil.systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION")) .thenReturn(null); - Assertions.assertThrows( + assertThrows( SFException.class, () -> SessionUtil.checkIfExperimentalAuthnEnabled( AuthenticatorType.OAUTH_AUTHORIZATION_CODE)); - Assertions.assertThrows( + assertThrows( SFException.class, () -> SessionUtil.checkIfExperimentalAuthnEnabled( AuthenticatorType.OAUTH_CLIENT_CREDENTIALS)); - Assertions.assertThrows( + assertThrows( SFException.class, () -> SessionUtil.checkIfExperimentalAuthnEnabled( AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN)); + assertThrows( + SFException.class, + () -> SessionUtil.checkIfExperimentalAuthnEnabled(AuthenticatorType.WORKLOAD_IDENTITY)); snowflakeUtilMockedStatic .when(() -> SnowflakeUtil.systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION")) .thenReturn("true"); - Assertions.assertDoesNotThrow( + assertDoesNotThrow( () -> SessionUtil.checkIfExperimentalAuthnEnabled( AuthenticatorType.OAUTH_AUTHORIZATION_CODE)); - Assertions.assertDoesNotThrow( + assertDoesNotThrow( () -> SessionUtil.checkIfExperimentalAuthnEnabled( AuthenticatorType.OAUTH_CLIENT_CREDENTIALS)); - Assertions.assertDoesNotThrow( + assertDoesNotThrow( () -> SessionUtil.checkIfExperimentalAuthnEnabled( AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN)); + assertDoesNotThrow( + () -> SessionUtil.checkIfExperimentalAuthnEnabled(AuthenticatorType.WORKLOAD_IDENTITY)); } } diff --git a/src/test/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreatorTest.java b/src/test/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreatorTest.java new file mode 100644 index 0000000000..6d6abb4afe --- /dev/null +++ b/src/test/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreatorTest.java @@ -0,0 +1,87 @@ +package net.snowflake.client.core.auth.wif; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.amazonaws.auth.BasicAWSCredentials; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class AwsIdentityAttestationCreatorTest { + + @Test + public void shouldReturnNullWhenNoCredentialsFound() { + AWSAttestationService attestationServiceMock = Mockito.mock(AWSAttestationService.class); + Mockito.when(attestationServiceMock.getAWSCredentials()).thenReturn(null); + AwsIdentityAttestationCreator attestationCreator = + new AwsIdentityAttestationCreator(attestationServiceMock); + assertNull(attestationCreator.createAttestation()); + } + + @Test + public void shouldReturnNullWhenNoRegion() { + AWSAttestationService attestationServiceMock = Mockito.mock(AWSAttestationService.class); + Mockito.when(attestationServiceMock.getAWSCredentials()) + .thenReturn(new BasicAWSCredentials("abc", "abc")); + Mockito.when(attestationServiceMock.getAWSRegion()).thenReturn(null); + + AwsIdentityAttestationCreator attestationCreator = + new AwsIdentityAttestationCreator(attestationServiceMock); + assertNull(attestationCreator.createAttestation()); + } + + @Test + public void shouldReturnNullWhenNoCallerIdentity() { + AWSAttestationService attestationServiceMock = Mockito.mock(AWSAttestationService.class); + Mockito.when(attestationServiceMock.getAWSCredentials()) + .thenReturn(new BasicAWSCredentials("abc", "abc")); + Mockito.when(attestationServiceMock.getAWSRegion()).thenReturn("eu-west-1"); + Mockito.when(attestationServiceMock.getArn()).thenReturn(null); + + AwsIdentityAttestationCreator attestationCreator = + new AwsIdentityAttestationCreator(attestationServiceMock); + assertNull(attestationCreator.createAttestation()); + } + + @Test + public void shouldReturnProperAttestationWithSignedRequestCredential() + throws JsonProcessingException { + AWSAttestationService attestationServiceSpy = Mockito.spy(AWSAttestationService.class); + Mockito.doReturn(new BasicAWSCredentials("abc", "abc")) + .when(attestationServiceSpy) + .getAWSCredentials(); + Mockito.doReturn("eu-west-1").when(attestationServiceSpy).getAWSRegion(); + Mockito.doReturn("arn:aws:attestation:abc").when(attestationServiceSpy).getArn(); + + AwsIdentityAttestationCreator attestationCreator = + new AwsIdentityAttestationCreator(attestationServiceSpy); + WorkloadIdentityAttestation attestation = attestationCreator.createAttestation(); + + assertNotNull(attestation); + assertEquals(WorkloadIdentityProviderType.AWS, attestation.getProvider()); + assertEquals("arn:aws:attestation:abc", attestation.getUserIdentifiedComponents().get("arn")); + assertNotNull(attestation.getCredential()); + Base64.Decoder decoder = Base64.getDecoder(); + String json = new String(decoder.decode(attestation.getCredential())); + Map credentialMap = new ObjectMapper().readValue(json, HashMap.class); + assertEquals(3, credentialMap.size()); + assertEquals( + "https://sts.eu-west-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + credentialMap.get("url")); + assertEquals("POST", credentialMap.get("method")); + assertNotNull(credentialMap.get("headers")); + Map headersMap = (Map) credentialMap.get("headers"); + assertEquals(4, headersMap.size()); + assertEquals("sts.eu-west-1.amazonaws.com", headersMap.get("Host")); + assertEquals("snowflakecomputing.com", headersMap.get("X-Snowflake-Audience")); + assertNotNull(headersMap.get("X-Amz-Date")); + assertTrue(headersMap.get("Authorization").matches("^AWS4-HMAC-SHA256 Credential=.*")); + } +} diff --git a/src/test/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProviderTest.java b/src/test/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProviderTest.java new file mode 100644 index 0000000000..d650ada01a --- /dev/null +++ b/src/test/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProviderTest.java @@ -0,0 +1,54 @@ +package net.snowflake.client.core.auth.wif; + +import java.util.HashMap; +import net.snowflake.client.core.SFException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +class WorkloadIdentityAttestationProviderTest { + + @Test + public void shouldCreateAttestationWithExplicitAWSProvider() throws SFException { + AwsIdentityAttestationCreator awsCreatorMock = + Mockito.mock(AwsIdentityAttestationCreator.class); + Mockito.when(awsCreatorMock.createAttestation()) + .thenReturn( + new WorkloadIdentityAttestation( + WorkloadIdentityProviderType.AWS, "credential_abc", new HashMap<>())); + WorkloadIdentityAttestationProvider provider = + new WorkloadIdentityAttestationProvider(awsCreatorMock, null, null, null); + + WorkloadIdentityAttestation attestation = + provider.getAttestation(WorkloadIdentityProviderType.AWS.name()); + Assertions.assertNotNull(attestation); + Assertions.assertEquals(WorkloadIdentityProviderType.AWS, attestation.getProvider()); + Assertions.assertEquals("credential_abc", attestation.getCredential()); + Assertions.assertEquals(new HashMap<>(), attestation.getUserIdentifiedComponents()); + } + + @Test + public void shouldCreateProperAttestationCreatorByType() throws SFException { + WorkloadIdentityAttestationProvider provider = + new WorkloadIdentityAttestationProvider( + new AwsIdentityAttestationCreator(null), + new GcpIdentityAttestationCreator(), + new AzureIdentityAttestationCreator(), + new OidcIdentityAttestationCreator()); + WorkloadIdentityAttestationCreator attestationCreator = + provider.getCreator(WorkloadIdentityProviderType.AWS.name()); + Assertions.assertInstanceOf(AwsIdentityAttestationCreator.class, attestationCreator); + + attestationCreator = provider.getCreator(WorkloadIdentityProviderType.AZURE.name()); + Assertions.assertInstanceOf(AzureIdentityAttestationCreator.class, attestationCreator); + + attestationCreator = provider.getCreator(WorkloadIdentityProviderType.GCP.name()); + Assertions.assertInstanceOf(GcpIdentityAttestationCreator.class, attestationCreator); + + attestationCreator = provider.getCreator(WorkloadIdentityProviderType.OIDC.name()); + Assertions.assertInstanceOf(OidcIdentityAttestationCreator.class, attestationCreator); + + Assertions.assertThrows( + SFException.class, () -> provider.getCreator("UNKNOWN_IDENTITY_PROVIDER")); + } +} diff --git a/thin_public_pom.xml b/thin_public_pom.xml index 2265d14ab6..77bfaeea50 100644 --- a/thin_public_pom.xml +++ b/thin_public_pom.xml @@ -127,6 +127,10 @@ com.amazonaws aws-java-sdk-core + + com.amazonaws + aws-java-sdk-sts + com.amazonaws aws-java-sdk-kms