Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions parent-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-s3</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sts</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
Expand Down
31 changes: 26 additions & 5 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.security.PrivateKey;
import java.time.Duration;
import java.util.Map;
import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation;
import net.snowflake.client.jdbc.ErrorCode;
import org.apache.http.client.methods.HttpRequestBase;

/** A class for holding all information required for login */
public class SFLoginInput {
Expand Down Expand Up @@ -57,6 +59,10 @@ public class SFLoginInput {
private boolean enableClientStoreTemporaryCredential;
private boolean enableClientRequestMfaToken;

// Workload Identity Federation
private String workloadIdentityProvider;
private WorkloadIdentityAttestation workloadIdentityAttestation;

// OAuth
private int redirectUriPort = -1;
private String clientId;
Expand Down Expand Up @@ -342,6 +348,15 @@ SFLoginInput setOauthRefreshToken(String oauthRefreshToken) {
return this;
}

String getWorkloadIdentityProvider() {
return workloadIdentityProvider;
}

SFLoginInput setWorkloadIdentityProvider(String workloadIdentityProvider) {
this.workloadIdentityProvider = workloadIdentityProvider;
return this;
}

@SnowflakeJdbcInternalApi
public String getDPoPPublicKey() {
return dpopPublicKey;
Expand Down Expand Up @@ -505,14 +520,11 @@ Map<String, String> getAdditionalHttpHeadersForSnowsight() {
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
* HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
* @see HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)
*/
public SFLoginInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
Expand Down Expand Up @@ -589,4 +601,13 @@ SFLoginInput setOriginalAuthenticator(String originalAuthenticator) {
this.originalAuthenticator = originalAuthenticator;
return this;
}

public void setWorkloadIdentityAttestation(
WorkloadIdentityAttestation workloadIdentityAttestation) {
this.workloadIdentityAttestation = workloadIdentityAttestation;
}

public WorkloadIdentityAttestation getWorkloadIdentityAttestation() {
return workloadIdentityAttestation;
}
}
2 changes: 2 additions & 0 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY))
.setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE))
.setOauthLoginInput(oauthLoginInput)
.setWorkloadIdentityProvider(
(String) connectionPropertiesMap.get(SFSessionProperty.WORKLOAD_IDENTITY_PROVIDER))
.setPrivateKeyBase64(
(String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64))
.setPrivateKeyPwd(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public enum SFSessionProperty {
OAUTH_SCOPE("oauthScope", false, String.class),
OAUTH_AUTHORIZATION_URL("oauthAuthorizationUrl", false, String.class),
OAUTH_TOKEN_REQUEST_URL("oauthTokenRequestUrl", false, String.class),
WORKLOAD_IDENTITY_PROVIDER("workloadIdentityProvider", false, String.class),
WAREHOUSE("warehouse", false, String.class),
LOGIN_TIMEOUT("loginTimeout", false, Integer.class),
NETWORK_TIMEOUT("networkTimeout", false, Integer.class),
Expand Down
41 changes: 40 additions & 1 deletion src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
import net.snowflake.client.core.auth.oauth.OAuthAccessTokenForRefreshTokenProvider;
import net.snowflake.client.core.auth.oauth.OAuthAccessTokenProviderFactory;
import net.snowflake.client.core.auth.oauth.TokenResponseDTO;
import net.snowflake.client.core.auth.wif.AWSAttestationService;
import net.snowflake.client.core.auth.wif.AwsIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.AzureIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.GcpIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.OidcIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation;
import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestationProvider;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.RetryContext;
import net.snowflake.client.jdbc.RetryContextManager;
Expand Down Expand Up @@ -242,6 +249,10 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) {
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN.name())) {
return AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN;
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.WORKLOAD_IDENTITY.name())) {
return AuthenticatorType.WORKLOAD_IDENTITY;
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.SNOWFLAKE_JWT.name())) {
Expand Down Expand Up @@ -328,6 +339,24 @@ static SFLoginOutput openSession(
}
}

if (authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) {
WorkloadIdentityAttestationProvider attestationProvider =
new WorkloadIdentityAttestationProvider(
new AwsIdentityAttestationCreator(new AWSAttestationService()),
new GcpIdentityAttestationCreator(),
new AzureIdentityAttestationCreator(),
new OidcIdentityAttestationCreator());
WorkloadIdentityAttestation attestation =
attestationProvider.getAttestation(loginInput.getWorkloadIdentityProvider());
if (attestation != null) {
loginInput.setWorkloadIdentityAttestation(attestation);
} else {
throw new SFException(
ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR,
"Unable to obtain workload identity attestation. Make sure that correct workload identity provider has been set and that Snowflake-JDBC driver runs on supported environment.");
}
}

convertSessionParameterStringValueToBooleanIfGiven(loginInput, CLIENT_REQUEST_MFA_TOKEN);

readCachedCredentialsIfPossible(loginInput);
Expand All @@ -353,7 +382,8 @@ static SFLoginOutput openSession(
static void checkIfExperimentalAuthnEnabled(AuthenticatorType authenticator) throws SFException {
if (authenticator.equals(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN)
|| authenticator.equals(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS)
|| authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) {
|| authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)
|| authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) {
boolean experimentalAuthenticationMethodsEnabled =
Boolean.parseBoolean(systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"));
AssertUtil.assertTrue(
Expand Down Expand Up @@ -650,6 +680,15 @@ static SFLoginOutput newSession(
data.put(ClientAuthnParameter.OAUTH_TYPE.name(), loginInput.getOriginalAuthenticator());
}

if (authenticatorType == AuthenticatorType.WORKLOAD_IDENTITY) {
data.put(
ClientAuthnParameter.TOKEN.name(),
loginInput.getWorkloadIdentityAttestation().getCredential());
data.put(
ClientAuthnParameter.PROVIDER.name(),
loginInput.getWorkloadIdentityAttestation().getProvider());
}

// map of client environment parameters, including connection parameters
// and environment properties like OS version, etc.
Map<String, Object> clientEnv = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,10 @@ public enum AuthenticatorType {
/*
* Authenticator to support PAT created in Snowflake
*/
PROGRAMMATIC_ACCESS_TOKEN
PROGRAMMATIC_ACCESS_TOKEN,

/*
* Authenticator to support existing authentication by existing AWS/GCP/Azure workload identity
*/
WORKLOAD_IDENTITY
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ public enum ClientAuthnParameter {
SESSION_PARAMETERS,
PROOF_KEY,
TOKEN,
OAUTH_TYPE
OAUTH_TYPE,
PROVIDER
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package net.snowflake.client.core.auth.wif;

import com.amazonaws.SignableRequest;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.regions.InstanceMetadataRegionProvider;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.SnowflakeUtil;

@SnowflakeJdbcInternalApi
public class AWSAttestationService {

AWSCredentials getAWSCredentials() {
return DefaultAWSCredentialsProviderChain.getInstance().getCredentials();
}

String getAWSRegion() {
String region = SnowflakeUtil.systemGetEnv("AWS_REGION");
if (region != null) {
return region;
} else {
return new InstanceMetadataRegionProvider().getRegion();
Comment thread
sfc-gh-pfus marked this conversation as resolved.
Outdated
}
}

String getArn() {
GetCallerIdentityResult callerIdentity =
AWSSecurityTokenServiceClientBuilder.defaultClient()
.getCallerIdentity(new GetCallerIdentityRequest());
if (callerIdentity != null) {
Comment thread
sfc-gh-pfus marked this conversation as resolved.
Outdated
return callerIdentity.getArn();
} else {
return null;
}
}

void signRequestWithSigV4(SignableRequest<Void> signableRequest, AWSCredentials awsCredentials) {
AWS4Signer sigV4Signer = new AWS4Signer();
Comment thread
sfc-gh-dheyman marked this conversation as resolved.
Outdated
sigV4Signer.setServiceName("sts");
sigV4Signer.sign(signableRequest, awsCredentials);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package net.snowflake.client.core.auth.wif;

import com.amazonaws.DefaultRequest;
import com.amazonaws.Request;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.http.HttpMethodName;
import com.google.gson.JsonObject;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class AwsIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AwsIdentityAttestationCreator.class);

private static final String SNOWFLAKE_AUDIENCE_HEADER_NAME = "X-Snowflake-Audience";
private static final String SNOWFLAKE_AUDIENCE = "snowflakecomputing.com";

private final AWSAttestationService attestationService;
Comment thread
sfc-gh-pmotacki marked this conversation as resolved.

public AwsIdentityAttestationCreator(AWSAttestationService attestationService) {
this.attestationService = attestationService;
}

@Override
public WorkloadIdentityAttestation createAttestation() {
logger.debug("Creating AWS identity attestation...");
AWSCredentials awsCredentials = attestationService.getAWSCredentials();
if (awsCredentials == null) {
logger.debug("No AWS credentials were found.");
return null;
}
String region = attestationService.getAWSRegion();
if (region == null) {
logger.debug("No AWS region was found.");
return null;
}
String arn = attestationService.getArn();
if (arn == null) {
logger.debug("No Caller Identity was found.");
return null;
}

String stsHostname = String.format("sts.%s.amazonaws.com", region);
Request<Void> request = createStsRequest(stsHostname);
attestationService.signRequestWithSigV4(request, awsCredentials);

String credential = createBase64EncodedRequestCredential(request);
return new WorkloadIdentityAttestation(
WorkloadIdentityProviderType.AWS,
credential,
new HashMap<String, String>() {
Comment thread
sfc-gh-pfus marked this conversation as resolved.
Outdated
{
put("arn", arn);
}
});
}

private Request<Void> createStsRequest(String hostname) {
Request<Void> request = new DefaultRequest<>("sts");
request.setHttpMethod(HttpMethodName.POST);
request.setEndpoint(
Comment thread
sfc-gh-pmotacki marked this conversation as resolved.
URI.create(
String.format("https://%s/?Action=GetCallerIdentity&Version=2011-06-15", hostname)));
request.addHeader("Host", hostname);
request.addHeader(SNOWFLAKE_AUDIENCE_HEADER_NAME, SNOWFLAKE_AUDIENCE);
return request;
}

private String createBase64EncodedRequestCredential(Request<Void> request) {
JsonObject assertionJson = new JsonObject();
JsonObject headers = new JsonObject();
request.getHeaders().forEach(headers::addProperty);
assertionJson.addProperty("url", request.getEndpoint().toString());
assertionJson.addProperty("method", request.getHttpMethod().toString());
assertionJson.add("headers", headers);

String assertionJsonString = assertionJson.toString();
Base64.Encoder encoder = Base64.getEncoder();
Comment thread
sfc-gh-pfus marked this conversation as resolved.
Outdated
return encoder.encodeToString(assertionJsonString.getBytes(StandardCharsets.UTF_8));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package net.snowflake.client.core.auth.wif;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class AzureIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AzureIdentityAttestationCreator.class);

@Override
public WorkloadIdentityAttestation createAttestation() throws SFException {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "Azure Workload Identity not supported");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package net.snowflake.client.core.auth.wif;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class GcpIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(GcpIdentityAttestationCreator.class);

@Override
public WorkloadIdentityAttestation createAttestation() throws SFException {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "GCP Workload Identity not supported");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package net.snowflake.client.core.auth.wif;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class OidcIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(OidcIdentityAttestationCreator.class);

@Override
public WorkloadIdentityAttestation createAttestation() throws SFException {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "OIDC Workload Identity not supported");
}
}
Loading
Loading