diff --git a/parent-pom.xml b/parent-pom.xml
index 3b1b41b014..257b2331a9 100644
--- a/parent-pom.xml
+++ b/parent-pom.xml
@@ -611,6 +611,10 @@
com.amazonaws
aws-java-sdk-s3
+
+ com.amazonaws
+ aws-java-sdk-sts
+
com.fasterxml.jackson.core
jackson-annotations
diff --git a/src/main/java/net/snowflake/client/core/SFLoginInput.java b/src/main/java/net/snowflake/client/core/SFLoginInput.java
index 72ed3157f2..21c24eadc3 100644
--- a/src/main/java/net/snowflake/client/core/SFLoginInput.java
+++ b/src/main/java/net/snowflake/client/core/SFLoginInput.java
@@ -5,7 +5,9 @@
import java.security.PrivateKey;
import java.time.Duration;
import java.util.Map;
+import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation;
import net.snowflake.client.jdbc.ErrorCode;
+import org.apache.http.client.methods.HttpRequestBase;
/** A class for holding all information required for login */
public class SFLoginInput {
@@ -57,6 +59,10 @@ public class SFLoginInput {
private boolean enableClientStoreTemporaryCredential;
private boolean enableClientRequestMfaToken;
+ // Workload Identity Federation
+ private String workloadIdentityProvider;
+ private WorkloadIdentityAttestation workloadIdentityAttestation;
+
// OAuth
private int redirectUriPort = -1;
private String clientId;
@@ -342,6 +348,15 @@ SFLoginInput setOauthRefreshToken(String oauthRefreshToken) {
return this;
}
+ String getWorkloadIdentityProvider() {
+ return workloadIdentityProvider;
+ }
+
+ SFLoginInput setWorkloadIdentityProvider(String workloadIdentityProvider) {
+ this.workloadIdentityProvider = workloadIdentityProvider;
+ return this;
+ }
+
@SnowflakeJdbcInternalApi
public String getDPoPPublicKey() {
return dpopPublicKey;
@@ -505,14 +520,11 @@ Map getAdditionalHttpHeadersForSnowsight() {
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
- * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
- * Map)}
+ * HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
- * @see
- * HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
- * Map)
+ * @see HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)
*/
public SFLoginInput setAdditionalHttpHeadersForSnowsight(
Map additionalHttpHeaders) {
@@ -589,4 +601,13 @@ SFLoginInput setOriginalAuthenticator(String originalAuthenticator) {
this.originalAuthenticator = originalAuthenticator;
return this;
}
+
+ public void setWorkloadIdentityAttestation(
+ WorkloadIdentityAttestation workloadIdentityAttestation) {
+ this.workloadIdentityAttestation = workloadIdentityAttestation;
+ }
+
+ public WorkloadIdentityAttestation getWorkloadIdentityAttestation() {
+ return workloadIdentityAttestation;
+ }
}
diff --git a/src/main/java/net/snowflake/client/core/SFSession.java b/src/main/java/net/snowflake/client/core/SFSession.java
index eafe2263d8..5a4fb1573a 100644
--- a/src/main/java/net/snowflake/client/core/SFSession.java
+++ b/src/main/java/net/snowflake/client/core/SFSession.java
@@ -721,6 +721,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY))
.setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE))
.setOauthLoginInput(oauthLoginInput)
+ .setWorkloadIdentityProvider(
+ (String) connectionPropertiesMap.get(SFSessionProperty.WORKLOAD_IDENTITY_PROVIDER))
.setPrivateKeyBase64(
(String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64))
.setPrivateKeyPwd(
diff --git a/src/main/java/net/snowflake/client/core/SFSessionProperty.java b/src/main/java/net/snowflake/client/core/SFSessionProperty.java
index bc20892d13..eba4f4ae01 100644
--- a/src/main/java/net/snowflake/client/core/SFSessionProperty.java
+++ b/src/main/java/net/snowflake/client/core/SFSessionProperty.java
@@ -26,6 +26,7 @@ public enum SFSessionProperty {
OAUTH_SCOPE("oauthScope", false, String.class),
OAUTH_AUTHORIZATION_URL("oauthAuthorizationUrl", false, String.class),
OAUTH_TOKEN_REQUEST_URL("oauthTokenRequestUrl", false, String.class),
+ WORKLOAD_IDENTITY_PROVIDER("workloadIdentityProvider", false, String.class),
WAREHOUSE("warehouse", false, String.class),
LOGIN_TIMEOUT("loginTimeout", false, Integer.class),
NETWORK_TIMEOUT("networkTimeout", false, Integer.class),
diff --git a/src/main/java/net/snowflake/client/core/SessionUtil.java b/src/main/java/net/snowflake/client/core/SessionUtil.java
index 275ad5a5a6..44e156784b 100644
--- a/src/main/java/net/snowflake/client/core/SessionUtil.java
+++ b/src/main/java/net/snowflake/client/core/SessionUtil.java
@@ -32,6 +32,13 @@
import net.snowflake.client.core.auth.oauth.OAuthAccessTokenForRefreshTokenProvider;
import net.snowflake.client.core.auth.oauth.OAuthAccessTokenProviderFactory;
import net.snowflake.client.core.auth.oauth.TokenResponseDTO;
+import net.snowflake.client.core.auth.wif.AWSAttestationService;
+import net.snowflake.client.core.auth.wif.AwsIdentityAttestationCreator;
+import net.snowflake.client.core.auth.wif.AzureIdentityAttestationCreator;
+import net.snowflake.client.core.auth.wif.GcpIdentityAttestationCreator;
+import net.snowflake.client.core.auth.wif.OidcIdentityAttestationCreator;
+import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation;
+import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestationProvider;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.RetryContext;
import net.snowflake.client.jdbc.RetryContextManager;
@@ -242,6 +249,10 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) {
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN.name())) {
return AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN;
+ } else if (loginInput
+ .getAuthenticator()
+ .equalsIgnoreCase(AuthenticatorType.WORKLOAD_IDENTITY.name())) {
+ return AuthenticatorType.WORKLOAD_IDENTITY;
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.SNOWFLAKE_JWT.name())) {
@@ -328,6 +339,24 @@ static SFLoginOutput openSession(
}
}
+ if (authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) {
+ WorkloadIdentityAttestationProvider attestationProvider =
+ new WorkloadIdentityAttestationProvider(
+ new AwsIdentityAttestationCreator(new AWSAttestationService()),
+ new GcpIdentityAttestationCreator(),
+ new AzureIdentityAttestationCreator(),
+ new OidcIdentityAttestationCreator());
+ WorkloadIdentityAttestation attestation =
+ attestationProvider.getAttestation(loginInput.getWorkloadIdentityProvider());
+ if (attestation != null) {
+ loginInput.setWorkloadIdentityAttestation(attestation);
+ } else {
+ throw new SFException(
+ ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR,
+ "Unable to obtain workload identity attestation. Make sure that correct workload identity provider has been set and that Snowflake-JDBC driver runs on supported environment.");
+ }
+ }
+
convertSessionParameterStringValueToBooleanIfGiven(loginInput, CLIENT_REQUEST_MFA_TOKEN);
readCachedCredentialsIfPossible(loginInput);
@@ -353,7 +382,8 @@ static SFLoginOutput openSession(
static void checkIfExperimentalAuthnEnabled(AuthenticatorType authenticator) throws SFException {
if (authenticator.equals(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN)
|| authenticator.equals(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS)
- || authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) {
+ || authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)
+ || authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) {
boolean experimentalAuthenticationMethodsEnabled =
Boolean.parseBoolean(systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"));
AssertUtil.assertTrue(
@@ -650,6 +680,15 @@ static SFLoginOutput newSession(
data.put(ClientAuthnParameter.OAUTH_TYPE.name(), loginInput.getOriginalAuthenticator());
}
+ if (authenticatorType == AuthenticatorType.WORKLOAD_IDENTITY) {
+ data.put(
+ ClientAuthnParameter.TOKEN.name(),
+ loginInput.getWorkloadIdentityAttestation().getCredential());
+ data.put(
+ ClientAuthnParameter.PROVIDER.name(),
+ loginInput.getWorkloadIdentityAttestation().getProvider());
+ }
+
// map of client environment parameters, including connection parameters
// and environment properties like OS version, etc.
Map clientEnv = new HashMap<>();
diff --git a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java
index b6a5919395..0251c65776 100644
--- a/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java
+++ b/src/main/java/net/snowflake/client/core/auth/AuthenticatorType.java
@@ -53,5 +53,10 @@ public enum AuthenticatorType {
/*
* Authenticator to support PAT created in Snowflake
*/
- PROGRAMMATIC_ACCESS_TOKEN
+ PROGRAMMATIC_ACCESS_TOKEN,
+
+ /*
+ * Authenticator to support existing authentication by existing AWS/GCP/Azure workload identity
+ */
+ WORKLOAD_IDENTITY
}
diff --git a/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java b/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java
index 6b8bc84935..2e4f2fcf83 100644
--- a/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java
+++ b/src/main/java/net/snowflake/client/core/auth/ClientAuthnParameter.java
@@ -18,5 +18,6 @@ public enum ClientAuthnParameter {
SESSION_PARAMETERS,
PROOF_KEY,
TOKEN,
- OAUTH_TYPE
+ OAUTH_TYPE,
+ PROVIDER
}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/AWSAttestationService.java b/src/main/java/net/snowflake/client/core/auth/wif/AWSAttestationService.java
new file mode 100644
index 0000000000..6a08485ca1
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/AWSAttestationService.java
@@ -0,0 +1,53 @@
+package net.snowflake.client.core.auth.wif;
+
+import com.amazonaws.SignableRequest;
+import com.amazonaws.auth.AWS4Signer;
+import com.amazonaws.auth.AWSCredentials;
+import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
+import com.amazonaws.regions.InstanceMetadataRegionProvider;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
+import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
+import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult;
+import java.util.Optional;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.jdbc.EnvironmentVariables;
+import net.snowflake.client.jdbc.SnowflakeUtil;
+
+@SnowflakeJdbcInternalApi
+public class AWSAttestationService {
+
+ private static final String SECURE_TOKEN_SERVICE_NAME = "sts";
+ private static boolean regionInitialized = false;
+ private static String region;
+
+ private final AWS4Signer aws4Signer;
+
+ public AWSAttestationService() {
+ aws4Signer = new AWS4Signer();
+ aws4Signer.setServiceName(SECURE_TOKEN_SERVICE_NAME);
+ }
+
+ AWSCredentials getAWSCredentials() {
+ return DefaultAWSCredentialsProviderChain.getInstance().getCredentials();
+ }
+
+ String getAWSRegion() {
+ if (!regionInitialized) {
+ String envRegion = SnowflakeUtil.systemGetEnv(EnvironmentVariables.AWS_REGION.getName());
+ region = envRegion != null ? envRegion : new InstanceMetadataRegionProvider().getRegion();
+ regionInitialized = true;
+ }
+ return region;
+ }
+
+ String getArn() {
+ GetCallerIdentityResult callerIdentity =
+ AWSSecurityTokenServiceClientBuilder.defaultClient()
+ .getCallerIdentity(new GetCallerIdentityRequest());
+ return Optional.ofNullable(callerIdentity).map(GetCallerIdentityResult::getArn).orElse(null);
+ }
+
+ void signRequestWithSigV4(SignableRequest signableRequest, AWSCredentials awsCredentials) {
+ aws4Signer.sign(signableRequest, awsCredentials);
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreator.java
new file mode 100644
index 0000000000..b38f070d83
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreator.java
@@ -0,0 +1,81 @@
+package net.snowflake.client.core.auth.wif;
+
+import com.amazonaws.DefaultRequest;
+import com.amazonaws.Request;
+import com.amazonaws.auth.AWSCredentials;
+import com.amazonaws.http.HttpMethodName;
+import java.net.URI;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Collections;
+import net.minidev.json.JSONObject;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.log.SFLogger;
+import net.snowflake.client.log.SFLoggerFactory;
+
+@SnowflakeJdbcInternalApi
+public class AwsIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {
+
+ private static final SFLogger logger =
+ SFLoggerFactory.getLogger(AwsIdentityAttestationCreator.class);
+
+ private static final String SNOWFLAKE_AUDIENCE_HEADER_NAME = "X-Snowflake-Audience";
+ private static final String SNOWFLAKE_AUDIENCE = "snowflakecomputing.com";
+
+ private final AWSAttestationService attestationService;
+
+ public AwsIdentityAttestationCreator(AWSAttestationService attestationService) {
+ this.attestationService = attestationService;
+ }
+
+ @Override
+ public WorkloadIdentityAttestation createAttestation() {
+ logger.debug("Creating AWS identity attestation...");
+ AWSCredentials awsCredentials = attestationService.getAWSCredentials();
+ if (awsCredentials == null) {
+ logger.debug("No AWS credentials were found.");
+ return null;
+ }
+ String region = attestationService.getAWSRegion();
+ if (region == null) {
+ logger.debug("No AWS region was found.");
+ return null;
+ }
+ String arn = attestationService.getArn();
+ if (arn == null) {
+ logger.debug("No Caller Identity was found.");
+ return null;
+ }
+
+ String stsHostname = String.format("sts.%s.amazonaws.com", region);
+ Request request = createStsRequest(stsHostname);
+ attestationService.signRequestWithSigV4(request, awsCredentials);
+
+ String credential = createBase64EncodedRequestCredential(request);
+ return new WorkloadIdentityAttestation(
+ WorkloadIdentityProviderType.AWS, credential, Collections.singletonMap("arn", arn));
+ }
+
+ private Request createStsRequest(String hostname) {
+ Request request = new DefaultRequest<>("sts");
+ request.setHttpMethod(HttpMethodName.POST);
+ request.setEndpoint(
+ URI.create(
+ String.format("https://%s/?Action=GetCallerIdentity&Version=2011-06-15", hostname)));
+ request.addHeader("Host", hostname);
+ request.addHeader(SNOWFLAKE_AUDIENCE_HEADER_NAME, SNOWFLAKE_AUDIENCE);
+ return request;
+ }
+
+ private String createBase64EncodedRequestCredential(Request request) {
+ JSONObject assertionJson = new JSONObject();
+ JSONObject headers = new JSONObject();
+ headers.putAll(request.getHeaders());
+ assertionJson.put("url", request.getEndpoint().toString());
+ assertionJson.put("method", request.getHttpMethod().toString());
+ assertionJson.put("headers", headers);
+
+ String assertionJsonString = assertionJson.toString();
+ return Base64.getEncoder().encodeToString(assertionJsonString.getBytes(StandardCharsets.UTF_8));
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/AzureIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/AzureIdentityAttestationCreator.java
new file mode 100644
index 0000000000..2ff1a9c1e2
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/AzureIdentityAttestationCreator.java
@@ -0,0 +1,19 @@
+package net.snowflake.client.core.auth.wif;
+
+import net.snowflake.client.core.SFException;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.jdbc.ErrorCode;
+import net.snowflake.client.log.SFLogger;
+import net.snowflake.client.log.SFLoggerFactory;
+
+@SnowflakeJdbcInternalApi
+public class AzureIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {
+
+ private static final SFLogger logger =
+ SFLoggerFactory.getLogger(AzureIdentityAttestationCreator.class);
+
+ @Override
+ public WorkloadIdentityAttestation createAttestation() throws SFException {
+ throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "Azure Workload Identity not supported");
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/GcpIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/GcpIdentityAttestationCreator.java
new file mode 100644
index 0000000000..c58c6633ca
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/GcpIdentityAttestationCreator.java
@@ -0,0 +1,19 @@
+package net.snowflake.client.core.auth.wif;
+
+import net.snowflake.client.core.SFException;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.jdbc.ErrorCode;
+import net.snowflake.client.log.SFLogger;
+import net.snowflake.client.log.SFLoggerFactory;
+
+@SnowflakeJdbcInternalApi
+public class GcpIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {
+
+ private static final SFLogger logger =
+ SFLoggerFactory.getLogger(GcpIdentityAttestationCreator.class);
+
+ @Override
+ public WorkloadIdentityAttestation createAttestation() throws SFException {
+ throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "GCP Workload Identity not supported");
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/OidcIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/OidcIdentityAttestationCreator.java
new file mode 100644
index 0000000000..edbae18b29
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/OidcIdentityAttestationCreator.java
@@ -0,0 +1,19 @@
+package net.snowflake.client.core.auth.wif;
+
+import net.snowflake.client.core.SFException;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.jdbc.ErrorCode;
+import net.snowflake.client.log.SFLogger;
+import net.snowflake.client.log.SFLoggerFactory;
+
+@SnowflakeJdbcInternalApi
+public class OidcIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {
+
+ private static final SFLogger logger =
+ SFLoggerFactory.getLogger(OidcIdentityAttestationCreator.class);
+
+ @Override
+ public WorkloadIdentityAttestation createAttestation() throws SFException {
+ throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "OIDC Workload Identity not supported");
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestation.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestation.java
new file mode 100644
index 0000000000..daf2cc32c6
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestation.java
@@ -0,0 +1,33 @@
+package net.snowflake.client.core.auth.wif;
+
+import java.util.Map;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+
+@SnowflakeJdbcInternalApi
+public class WorkloadIdentityAttestation {
+
+ private final WorkloadIdentityProviderType provider;
+ private final String credential;
+ private final Map userIdentifiedComponents;
+
+ WorkloadIdentityAttestation(
+ WorkloadIdentityProviderType provider,
+ String credential,
+ Map userIdentifiedComponents) {
+ this.provider = provider;
+ this.credential = credential;
+ this.userIdentifiedComponents = userIdentifiedComponents;
+ }
+
+ public WorkloadIdentityProviderType getProvider() {
+ return provider;
+ }
+
+ public String getCredential() {
+ return credential;
+ }
+
+ public Map getUserIdentifiedComponents() {
+ return userIdentifiedComponents;
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationCreator.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationCreator.java
new file mode 100644
index 0000000000..736f6a8400
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationCreator.java
@@ -0,0 +1,8 @@
+package net.snowflake.client.core.auth.wif;
+
+import net.snowflake.client.core.SFException;
+
+interface WorkloadIdentityAttestationCreator {
+
+ WorkloadIdentityAttestation createAttestation() throws SFException;
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProvider.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProvider.java
new file mode 100644
index 0000000000..87b6bf9363
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProvider.java
@@ -0,0 +1,78 @@
+package net.snowflake.client.core.auth.wif;
+
+import com.google.common.base.Strings;
+import net.snowflake.client.core.SFException;
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+import net.snowflake.client.jdbc.ErrorCode;
+import net.snowflake.client.log.SFLogger;
+import net.snowflake.client.log.SFLoggerFactory;
+
+@SnowflakeJdbcInternalApi
+public class WorkloadIdentityAttestationProvider {
+
+ private static final SFLogger logger =
+ SFLoggerFactory.getLogger(WorkloadIdentityAttestationProvider.class);
+
+ private final AwsIdentityAttestationCreator awsAttestationCreator;
+ private final GcpIdentityAttestationCreator gcpAttestationCreator;
+ private final AzureIdentityAttestationCreator azureAttestationCreator;
+ private final OidcIdentityAttestationCreator oidcAttestationCreator;
+
+ public WorkloadIdentityAttestationProvider(
+ AwsIdentityAttestationCreator awsAttestationCreator,
+ GcpIdentityAttestationCreator gcpAttestationCreator,
+ AzureIdentityAttestationCreator azureAttestationCreator,
+ OidcIdentityAttestationCreator oidcAttestationCreator) {
+ this.awsAttestationCreator = awsAttestationCreator;
+ this.gcpAttestationCreator = gcpAttestationCreator;
+ this.azureAttestationCreator = azureAttestationCreator;
+ this.oidcAttestationCreator = oidcAttestationCreator;
+ }
+
+ public WorkloadIdentityAttestation getAttestation(String identityProvider) throws SFException {
+ if (Strings.isNullOrEmpty(identityProvider)) {
+ logger.debug("Workload Identity Provider has not been specified. Using autodetect...");
+ return createAutodetectAttestation();
+ } else {
+ return getCreator(identityProvider).createAttestation();
+ }
+ }
+
+ WorkloadIdentityAttestationCreator getCreator(String identityProvider) throws SFException {
+ if (WorkloadIdentityProviderType.AWS.name().equalsIgnoreCase(identityProvider)) {
+ return awsAttestationCreator;
+ } else if (WorkloadIdentityProviderType.GCP.name().equalsIgnoreCase(identityProvider)) {
+ return gcpAttestationCreator;
+ } else if (WorkloadIdentityProviderType.AZURE.name().equalsIgnoreCase(identityProvider)) {
+ return azureAttestationCreator;
+ } else if (WorkloadIdentityProviderType.OIDC.name().equalsIgnoreCase(identityProvider)) {
+ return oidcAttestationCreator;
+ } else {
+ throw new SFException(
+ ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR,
+ "Unknown Workload Identity provider specified: " + identityProvider);
+ }
+ }
+
+ private WorkloadIdentityAttestation createAutodetectAttestation() throws SFException {
+ WorkloadIdentityAttestation awsAttestation = awsAttestationCreator.createAttestation();
+ if (awsAttestation != null) {
+ return awsAttestation;
+ }
+ WorkloadIdentityAttestation gcpAttestation = gcpAttestationCreator.createAttestation();
+ if (gcpAttestation != null) {
+ return gcpAttestation;
+ }
+ WorkloadIdentityAttestation azureAttestation = azureAttestationCreator.createAttestation();
+ if (azureAttestation != null) {
+ return azureAttestation;
+ }
+ WorkloadIdentityAttestation oidcAttestation = oidcAttestationCreator.createAttestation();
+ if (oidcAttestation != null) {
+ return oidcAttestation;
+ }
+ throw new SFException(
+ ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR,
+ "Unable to autodetect Workload Identity. None of supported Workload Identity environments has been identified.");
+ }
+}
diff --git a/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityProviderType.java b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityProviderType.java
new file mode 100644
index 0000000000..74b2c6855d
--- /dev/null
+++ b/src/main/java/net/snowflake/client/core/auth/wif/WorkloadIdentityProviderType.java
@@ -0,0 +1,9 @@
+package net.snowflake.client.core.auth.wif;
+
+enum WorkloadIdentityProviderType {
+ AWS, // Provider that builds an encoded pre-signed GetCallerIdentity request using the current
+ // workload's IAM role.
+ AZURE, // Provider that requests an OAuth access token for the workload's managed identity.
+ GCP, // Provider that requests an ID token for the workload's attached service account.
+ OIDC // Provider that looks for an OIDC ID token.
+}
diff --git a/src/main/java/net/snowflake/client/jdbc/EnvironmentVariables.java b/src/main/java/net/snowflake/client/jdbc/EnvironmentVariables.java
new file mode 100644
index 0000000000..841b5e64c8
--- /dev/null
+++ b/src/main/java/net/snowflake/client/jdbc/EnvironmentVariables.java
@@ -0,0 +1,18 @@
+package net.snowflake.client.jdbc;
+
+import net.snowflake.client.core.SnowflakeJdbcInternalApi;
+
+@SnowflakeJdbcInternalApi
+public enum EnvironmentVariables {
+ AWS_REGION("AWS_REGION");
+
+ private final String name;
+
+ EnvironmentVariables(String name) {
+ this.name = name;
+ }
+
+ public String getName() {
+ return name;
+ }
+}
diff --git a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java
index 884df03b70..fb1a373190 100644
--- a/src/main/java/net/snowflake/client/jdbc/ErrorCode.java
+++ b/src/main/java/net/snowflake/client/jdbc/ErrorCode.java
@@ -82,7 +82,8 @@ public enum ErrorCode {
FILE_OPERATION_DOWNLOAD_ERROR(200067, SqlState.INTERNAL_ERROR),
OAUTH_AUTHORIZATION_CODE_FLOW_ERROR(200068, SqlState.CONNECTION_EXCEPTION),
OAUTH_CLIENT_CREDENTIALS_FLOW_ERROR(200069, SqlState.CONNECTION_EXCEPTION),
- OAUTH_REFRESH_TOKEN_FLOW_ERROR(200070, SqlState.CONNECTION_EXCEPTION);
+ OAUTH_REFRESH_TOKEN_FLOW_ERROR(200070, SqlState.CONNECTION_EXCEPTION),
+ WORKFLOW_IDENTITY_FLOW_ERROR(200071, SqlState.CONNECTION_EXCEPTION);
public static final String errorMessageResource = "net.snowflake.client.jdbc.jdbc_error_messages";
diff --git a/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties b/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties
index c3a74d6942..3f5ec0d986 100644
--- a/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties
+++ b/src/main/resources/net/snowflake/client/jdbc/jdbc_error_messages.properties
@@ -89,3 +89,4 @@ Error message={3}, Extended error info={4}
200068=Error during OAuth Authorization Code authentication: {0}
200069=Error during OAuth Client Credentials authentication: {0}
200070=Error during obtaining OAuth access token using refresh token: {0}
+200071=Error during Workflow Identity authentication: {0}
diff --git a/src/test/java/net/snowflake/client/core/SessionUtilTest.java b/src/test/java/net/snowflake/client/core/SessionUtilTest.java
index 788864f545..25a3692352 100644
--- a/src/test/java/net/snowflake/client/core/SessionUtilTest.java
+++ b/src/test/java/net/snowflake/client/core/SessionUtilTest.java
@@ -1,8 +1,10 @@
package net.snowflake.client.core;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mockStatic;
@@ -22,7 +24,6 @@
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.junit.jupiter.api.AfterAll;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
@@ -255,37 +256,42 @@ public void shouldProperlyCheckIfExperimentalAuthEnabled() {
snowflakeUtilMockedStatic
.when(() -> SnowflakeUtil.systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"))
.thenReturn(null);
- Assertions.assertThrows(
+ assertThrows(
SFException.class,
() ->
SessionUtil.checkIfExperimentalAuthnEnabled(
AuthenticatorType.OAUTH_AUTHORIZATION_CODE));
- Assertions.assertThrows(
+ assertThrows(
SFException.class,
() ->
SessionUtil.checkIfExperimentalAuthnEnabled(
AuthenticatorType.OAUTH_CLIENT_CREDENTIALS));
- Assertions.assertThrows(
+ assertThrows(
SFException.class,
() ->
SessionUtil.checkIfExperimentalAuthnEnabled(
AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN));
+ assertThrows(
+ SFException.class,
+ () -> SessionUtil.checkIfExperimentalAuthnEnabled(AuthenticatorType.WORKLOAD_IDENTITY));
snowflakeUtilMockedStatic
.when(() -> SnowflakeUtil.systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"))
.thenReturn("true");
- Assertions.assertDoesNotThrow(
+ assertDoesNotThrow(
() ->
SessionUtil.checkIfExperimentalAuthnEnabled(
AuthenticatorType.OAUTH_AUTHORIZATION_CODE));
- Assertions.assertDoesNotThrow(
+ assertDoesNotThrow(
() ->
SessionUtil.checkIfExperimentalAuthnEnabled(
AuthenticatorType.OAUTH_CLIENT_CREDENTIALS));
- Assertions.assertDoesNotThrow(
+ assertDoesNotThrow(
() ->
SessionUtil.checkIfExperimentalAuthnEnabled(
AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN));
+ assertDoesNotThrow(
+ () -> SessionUtil.checkIfExperimentalAuthnEnabled(AuthenticatorType.WORKLOAD_IDENTITY));
}
}
diff --git a/src/test/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreatorTest.java b/src/test/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreatorTest.java
new file mode 100644
index 0000000000..6d6abb4afe
--- /dev/null
+++ b/src/test/java/net/snowflake/client/core/auth/wif/AwsIdentityAttestationCreatorTest.java
@@ -0,0 +1,87 @@
+package net.snowflake.client.core.auth.wif;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.amazonaws.auth.BasicAWSCredentials;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import java.util.Base64;
+import java.util.HashMap;
+import java.util.Map;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+public class AwsIdentityAttestationCreatorTest {
+
+ @Test
+ public void shouldReturnNullWhenNoCredentialsFound() {
+ AWSAttestationService attestationServiceMock = Mockito.mock(AWSAttestationService.class);
+ Mockito.when(attestationServiceMock.getAWSCredentials()).thenReturn(null);
+ AwsIdentityAttestationCreator attestationCreator =
+ new AwsIdentityAttestationCreator(attestationServiceMock);
+ assertNull(attestationCreator.createAttestation());
+ }
+
+ @Test
+ public void shouldReturnNullWhenNoRegion() {
+ AWSAttestationService attestationServiceMock = Mockito.mock(AWSAttestationService.class);
+ Mockito.when(attestationServiceMock.getAWSCredentials())
+ .thenReturn(new BasicAWSCredentials("abc", "abc"));
+ Mockito.when(attestationServiceMock.getAWSRegion()).thenReturn(null);
+
+ AwsIdentityAttestationCreator attestationCreator =
+ new AwsIdentityAttestationCreator(attestationServiceMock);
+ assertNull(attestationCreator.createAttestation());
+ }
+
+ @Test
+ public void shouldReturnNullWhenNoCallerIdentity() {
+ AWSAttestationService attestationServiceMock = Mockito.mock(AWSAttestationService.class);
+ Mockito.when(attestationServiceMock.getAWSCredentials())
+ .thenReturn(new BasicAWSCredentials("abc", "abc"));
+ Mockito.when(attestationServiceMock.getAWSRegion()).thenReturn("eu-west-1");
+ Mockito.when(attestationServiceMock.getArn()).thenReturn(null);
+
+ AwsIdentityAttestationCreator attestationCreator =
+ new AwsIdentityAttestationCreator(attestationServiceMock);
+ assertNull(attestationCreator.createAttestation());
+ }
+
+ @Test
+ public void shouldReturnProperAttestationWithSignedRequestCredential()
+ throws JsonProcessingException {
+ AWSAttestationService attestationServiceSpy = Mockito.spy(AWSAttestationService.class);
+ Mockito.doReturn(new BasicAWSCredentials("abc", "abc"))
+ .when(attestationServiceSpy)
+ .getAWSCredentials();
+ Mockito.doReturn("eu-west-1").when(attestationServiceSpy).getAWSRegion();
+ Mockito.doReturn("arn:aws:attestation:abc").when(attestationServiceSpy).getArn();
+
+ AwsIdentityAttestationCreator attestationCreator =
+ new AwsIdentityAttestationCreator(attestationServiceSpy);
+ WorkloadIdentityAttestation attestation = attestationCreator.createAttestation();
+
+ assertNotNull(attestation);
+ assertEquals(WorkloadIdentityProviderType.AWS, attestation.getProvider());
+ assertEquals("arn:aws:attestation:abc", attestation.getUserIdentifiedComponents().get("arn"));
+ assertNotNull(attestation.getCredential());
+ Base64.Decoder decoder = Base64.getDecoder();
+ String json = new String(decoder.decode(attestation.getCredential()));
+ Map credentialMap = new ObjectMapper().readValue(json, HashMap.class);
+ assertEquals(3, credentialMap.size());
+ assertEquals(
+ "https://sts.eu-west-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15",
+ credentialMap.get("url"));
+ assertEquals("POST", credentialMap.get("method"));
+ assertNotNull(credentialMap.get("headers"));
+ Map headersMap = (Map) credentialMap.get("headers");
+ assertEquals(4, headersMap.size());
+ assertEquals("sts.eu-west-1.amazonaws.com", headersMap.get("Host"));
+ assertEquals("snowflakecomputing.com", headersMap.get("X-Snowflake-Audience"));
+ assertNotNull(headersMap.get("X-Amz-Date"));
+ assertTrue(headersMap.get("Authorization").matches("^AWS4-HMAC-SHA256 Credential=.*"));
+ }
+}
diff --git a/src/test/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProviderTest.java b/src/test/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProviderTest.java
new file mode 100644
index 0000000000..d650ada01a
--- /dev/null
+++ b/src/test/java/net/snowflake/client/core/auth/wif/WorkloadIdentityAttestationProviderTest.java
@@ -0,0 +1,54 @@
+package net.snowflake.client.core.auth.wif;
+
+import java.util.HashMap;
+import net.snowflake.client.core.SFException;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+class WorkloadIdentityAttestationProviderTest {
+
+ @Test
+ public void shouldCreateAttestationWithExplicitAWSProvider() throws SFException {
+ AwsIdentityAttestationCreator awsCreatorMock =
+ Mockito.mock(AwsIdentityAttestationCreator.class);
+ Mockito.when(awsCreatorMock.createAttestation())
+ .thenReturn(
+ new WorkloadIdentityAttestation(
+ WorkloadIdentityProviderType.AWS, "credential_abc", new HashMap<>()));
+ WorkloadIdentityAttestationProvider provider =
+ new WorkloadIdentityAttestationProvider(awsCreatorMock, null, null, null);
+
+ WorkloadIdentityAttestation attestation =
+ provider.getAttestation(WorkloadIdentityProviderType.AWS.name());
+ Assertions.assertNotNull(attestation);
+ Assertions.assertEquals(WorkloadIdentityProviderType.AWS, attestation.getProvider());
+ Assertions.assertEquals("credential_abc", attestation.getCredential());
+ Assertions.assertEquals(new HashMap<>(), attestation.getUserIdentifiedComponents());
+ }
+
+ @Test
+ public void shouldCreateProperAttestationCreatorByType() throws SFException {
+ WorkloadIdentityAttestationProvider provider =
+ new WorkloadIdentityAttestationProvider(
+ new AwsIdentityAttestationCreator(null),
+ new GcpIdentityAttestationCreator(),
+ new AzureIdentityAttestationCreator(),
+ new OidcIdentityAttestationCreator());
+ WorkloadIdentityAttestationCreator attestationCreator =
+ provider.getCreator(WorkloadIdentityProviderType.AWS.name());
+ Assertions.assertInstanceOf(AwsIdentityAttestationCreator.class, attestationCreator);
+
+ attestationCreator = provider.getCreator(WorkloadIdentityProviderType.AZURE.name());
+ Assertions.assertInstanceOf(AzureIdentityAttestationCreator.class, attestationCreator);
+
+ attestationCreator = provider.getCreator(WorkloadIdentityProviderType.GCP.name());
+ Assertions.assertInstanceOf(GcpIdentityAttestationCreator.class, attestationCreator);
+
+ attestationCreator = provider.getCreator(WorkloadIdentityProviderType.OIDC.name());
+ Assertions.assertInstanceOf(OidcIdentityAttestationCreator.class, attestationCreator);
+
+ Assertions.assertThrows(
+ SFException.class, () -> provider.getCreator("UNKNOWN_IDENTITY_PROVIDER"));
+ }
+}
diff --git a/thin_public_pom.xml b/thin_public_pom.xml
index 2265d14ab6..77bfaeea50 100644
--- a/thin_public_pom.xml
+++ b/thin_public_pom.xml
@@ -127,6 +127,10 @@
com.amazonaws
aws-java-sdk-core
+
+ com.amazonaws
+ aws-java-sdk-sts
+
com.amazonaws
aws-java-sdk-kms