Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions parent-pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-s3</artifactId>
</dependency>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-sts</artifactId>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
Expand Down
31 changes: 26 additions & 5 deletions src/main/java/net/snowflake/client/core/SFLoginInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import java.security.PrivateKey;
import java.time.Duration;
import java.util.Map;
import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation;
import net.snowflake.client.jdbc.ErrorCode;
import org.apache.http.client.methods.HttpRequestBase;

/** A class for holding all information required for login */
public class SFLoginInput {
Expand Down Expand Up @@ -57,6 +59,10 @@ public class SFLoginInput {
private boolean enableClientStoreTemporaryCredential;
private boolean enableClientRequestMfaToken;

// Workload Identity Federation
private String workloadIdentityProvider;
private WorkloadIdentityAttestation workloadIdentityAttestation;

// OAuth
private int redirectUriPort = -1;
private String clientId;
Expand Down Expand Up @@ -342,6 +348,15 @@ SFLoginInput setOauthRefreshToken(String oauthRefreshToken) {
return this;
}

String getWorkloadIdentityProvider() {
return workloadIdentityProvider;
}

SFLoginInput setWorkloadIdentityProvider(String workloadIdentityProvider) {
this.workloadIdentityProvider = workloadIdentityProvider;
return this;
}

@SnowflakeJdbcInternalApi
public String getDPoPPublicKey() {
return dpopPublicKey;
Expand Down Expand Up @@ -505,14 +520,11 @@ Map<String, String> getAdditionalHttpHeadersForSnowsight() {
* Set additional http headers to apply to the outgoing request. The additional headers cannot be
* used to replace or overwrite a header in use by the driver. These will be applied to the
* outgoing request. Primarily used by Snowsight, as described in {@link
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)}
* HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)}
*
* @param additionalHttpHeaders The new headers to add
* @return The input object, for chaining
* @see
* HttpUtil#applyAdditionalHeadersForSnowsight(org.apache.http.client.methods.HttpRequestBase,
* Map)
* @see HttpUtil#applyAdditionalHeadersForSnowsight(HttpRequestBase, Map)
*/
public SFLoginInput setAdditionalHttpHeadersForSnowsight(
Map<String, String> additionalHttpHeaders) {
Expand Down Expand Up @@ -589,4 +601,13 @@ SFLoginInput setOriginalAuthenticator(String originalAuthenticator) {
this.originalAuthenticator = originalAuthenticator;
return this;
}

public void setWorkloadIdentityAttestation(
WorkloadIdentityAttestation workloadIdentityAttestation) {
this.workloadIdentityAttestation = workloadIdentityAttestation;
}

public WorkloadIdentityAttestation getWorkloadIdentityAttestation() {
return workloadIdentityAttestation;
}
}
2 changes: 2 additions & 0 deletions src/main/java/net/snowflake/client/core/SFSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,8 @@ public synchronized void open() throws SFException, SnowflakeSQLException {
.setPrivateKey((PrivateKey) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY))
.setPrivateKeyFile((String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_FILE))
.setOauthLoginInput(oauthLoginInput)
.setWorkloadIdentityProvider(
(String) connectionPropertiesMap.get(SFSessionProperty.WORKLOAD_IDENTITY_PROVIDER))
.setPrivateKeyBase64(
(String) connectionPropertiesMap.get(SFSessionProperty.PRIVATE_KEY_BASE64))
.setPrivateKeyPwd(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public enum SFSessionProperty {
OAUTH_SCOPE("oauthScope", false, String.class),
OAUTH_AUTHORIZATION_URL("oauthAuthorizationUrl", false, String.class),
OAUTH_TOKEN_REQUEST_URL("oauthTokenRequestUrl", false, String.class),
WORKLOAD_IDENTITY_PROVIDER("workloadIdentityProvider", false, String.class),
WAREHOUSE("warehouse", false, String.class),
LOGIN_TIMEOUT("loginTimeout", false, Integer.class),
NETWORK_TIMEOUT("networkTimeout", false, Integer.class),
Expand Down
41 changes: 40 additions & 1 deletion src/main/java/net/snowflake/client/core/SessionUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
import net.snowflake.client.core.auth.oauth.OAuthAccessTokenForRefreshTokenProvider;
import net.snowflake.client.core.auth.oauth.OAuthAccessTokenProviderFactory;
import net.snowflake.client.core.auth.oauth.TokenResponseDTO;
import net.snowflake.client.core.auth.wif.AWSAttestationService;
import net.snowflake.client.core.auth.wif.AwsIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.AzureIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.GcpIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.OidcIdentityAttestationCreator;
import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestation;
import net.snowflake.client.core.auth.wif.WorkloadIdentityAttestationProvider;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.RetryContext;
import net.snowflake.client.jdbc.RetryContextManager;
Expand Down Expand Up @@ -242,6 +249,10 @@ private static AuthenticatorType getAuthenticator(SFLoginInput loginInput) {
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN.name())) {
return AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN;
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.WORKLOAD_IDENTITY.name())) {
return AuthenticatorType.WORKLOAD_IDENTITY;
} else if (loginInput
.getAuthenticator()
.equalsIgnoreCase(AuthenticatorType.SNOWFLAKE_JWT.name())) {
Expand Down Expand Up @@ -328,6 +339,24 @@ static SFLoginOutput openSession(
}
}

if (authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) {
WorkloadIdentityAttestationProvider attestationProvider =
new WorkloadIdentityAttestationProvider(
new AwsIdentityAttestationCreator(new AWSAttestationService()),
new GcpIdentityAttestationCreator(),
new AzureIdentityAttestationCreator(),
new OidcIdentityAttestationCreator());
WorkloadIdentityAttestation attestation =
attestationProvider.getAttestation(loginInput.getWorkloadIdentityProvider());
if (attestation != null) {
loginInput.setWorkloadIdentityAttestation(attestation);
} else {
throw new SFException(
ErrorCode.WORKFLOW_IDENTITY_FLOW_ERROR,
"Unable to obtain workload identity attestation. Make sure that correct workload identity provider has been set and that Snowflake-JDBC driver runs on supported environment.");
}
}

convertSessionParameterStringValueToBooleanIfGiven(loginInput, CLIENT_REQUEST_MFA_TOKEN);

readCachedCredentialsIfPossible(loginInput);
Expand All @@ -353,7 +382,8 @@ static SFLoginOutput openSession(
static void checkIfExperimentalAuthnEnabled(AuthenticatorType authenticator) throws SFException {
if (authenticator.equals(AuthenticatorType.PROGRAMMATIC_ACCESS_TOKEN)
|| authenticator.equals(AuthenticatorType.OAUTH_CLIENT_CREDENTIALS)
|| authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)) {
|| authenticator.equals(AuthenticatorType.OAUTH_AUTHORIZATION_CODE)
|| authenticator.equals(AuthenticatorType.WORKLOAD_IDENTITY)) {
boolean experimentalAuthenticationMethodsEnabled =
Boolean.parseBoolean(systemGetEnv("SF_ENABLE_EXPERIMENTAL_AUTHENTICATION"));
AssertUtil.assertTrue(
Expand Down Expand Up @@ -650,6 +680,15 @@ static SFLoginOutput newSession(
data.put(ClientAuthnParameter.OAUTH_TYPE.name(), loginInput.getOriginalAuthenticator());
}

if (authenticatorType == AuthenticatorType.WORKLOAD_IDENTITY) {
data.put(
ClientAuthnParameter.TOKEN.name(),
loginInput.getWorkloadIdentityAttestation().getCredential());
data.put(
ClientAuthnParameter.PROVIDER.name(),
loginInput.getWorkloadIdentityAttestation().getProvider());
}

// map of client environment parameters, including connection parameters
// and environment properties like OS version, etc.
Map<String, Object> clientEnv = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,10 @@ public enum AuthenticatorType {
/*
* Authenticator to support PAT created in Snowflake
*/
PROGRAMMATIC_ACCESS_TOKEN
PROGRAMMATIC_ACCESS_TOKEN,

/*
* Authenticator to support existing authentication by existing AWS/GCP/Azure workload identity
*/
WORKLOAD_IDENTITY
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ public enum ClientAuthnParameter {
SESSION_PARAMETERS,
PROOF_KEY,
TOKEN,
OAUTH_TYPE
OAUTH_TYPE,
PROVIDER
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package net.snowflake.client.core.auth.wif;

import com.amazonaws.SignableRequest;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.regions.InstanceMetadataRegionProvider;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult;
import java.util.Optional;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.EnvironmentVariables;
import net.snowflake.client.jdbc.SnowflakeUtil;

@SnowflakeJdbcInternalApi
public class AWSAttestationService {

private static final String SECURE_TOKEN_SERVICE_NAME = "sts";
private static boolean regionInitialized = false;
private static String region;

private final AWS4Signer aws4Signer;

public AWSAttestationService() {
aws4Signer = new AWS4Signer();
aws4Signer.setServiceName(SECURE_TOKEN_SERVICE_NAME);
}

AWSCredentials getAWSCredentials() {
return DefaultAWSCredentialsProviderChain.getInstance().getCredentials();
}

String getAWSRegion() {
if (!regionInitialized) {
Comment thread
sfc-gh-dheyman marked this conversation as resolved.
String envRegion = SnowflakeUtil.systemGetEnv(EnvironmentVariables.AWS_REGION.getName());
region = envRegion != null ? envRegion : new InstanceMetadataRegionProvider().getRegion();
Comment thread
sfc-gh-pfus marked this conversation as resolved.
regionInitialized = true;
Comment thread
sfc-gh-pmotacki marked this conversation as resolved.
}
return region;
}

String getArn() {
GetCallerIdentityResult callerIdentity =
AWSSecurityTokenServiceClientBuilder.defaultClient()
.getCallerIdentity(new GetCallerIdentityRequest());
return Optional.ofNullable(callerIdentity).map(GetCallerIdentityResult::getArn).orElse(null);
Comment thread
sfc-gh-pmotacki marked this conversation as resolved.
}

void signRequestWithSigV4(SignableRequest<Void> signableRequest, AWSCredentials awsCredentials) {
aws4Signer.sign(signableRequest, awsCredentials);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package net.snowflake.client.core.auth.wif;

import com.amazonaws.DefaultRequest;
import com.amazonaws.Request;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.http.HttpMethodName;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import net.minidev.json.JSONObject;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class AwsIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AwsIdentityAttestationCreator.class);

private static final String SNOWFLAKE_AUDIENCE_HEADER_NAME = "X-Snowflake-Audience";
private static final String SNOWFLAKE_AUDIENCE = "snowflakecomputing.com";

private final AWSAttestationService attestationService;
Comment thread
sfc-gh-pmotacki marked this conversation as resolved.

public AwsIdentityAttestationCreator(AWSAttestationService attestationService) {
this.attestationService = attestationService;
}

@Override
public WorkloadIdentityAttestation createAttestation() {
logger.debug("Creating AWS identity attestation...");
AWSCredentials awsCredentials = attestationService.getAWSCredentials();
if (awsCredentials == null) {
logger.debug("No AWS credentials were found.");
return null;
}
String region = attestationService.getAWSRegion();
if (region == null) {
logger.debug("No AWS region was found.");
return null;
}
String arn = attestationService.getArn();
if (arn == null) {
logger.debug("No Caller Identity was found.");
return null;
}

String stsHostname = String.format("sts.%s.amazonaws.com", region);
Request<Void> request = createStsRequest(stsHostname);
attestationService.signRequestWithSigV4(request, awsCredentials);

String credential = createBase64EncodedRequestCredential(request);
return new WorkloadIdentityAttestation(
WorkloadIdentityProviderType.AWS, credential, Collections.singletonMap("arn", arn));
}

private Request<Void> createStsRequest(String hostname) {
Request<Void> request = new DefaultRequest<>("sts");
request.setHttpMethod(HttpMethodName.POST);
request.setEndpoint(
Comment thread
sfc-gh-pmotacki marked this conversation as resolved.
URI.create(
String.format("https://%s/?Action=GetCallerIdentity&Version=2011-06-15", hostname)));
request.addHeader("Host", hostname);
request.addHeader(SNOWFLAKE_AUDIENCE_HEADER_NAME, SNOWFLAKE_AUDIENCE);
return request;
}

private String createBase64EncodedRequestCredential(Request<Void> request) {
JSONObject assertionJson = new JSONObject();
JSONObject headers = new JSONObject();
headers.putAll(request.getHeaders());
assertionJson.put("url", request.getEndpoint().toString());
assertionJson.put("method", request.getHttpMethod().toString());
assertionJson.put("headers", headers);

String assertionJsonString = assertionJson.toString();
return Base64.getEncoder().encodeToString(assertionJsonString.getBytes(StandardCharsets.UTF_8));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package net.snowflake.client.core.auth.wif;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class AzureIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(AzureIdentityAttestationCreator.class);

@Override
public WorkloadIdentityAttestation createAttestation() throws SFException {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "Azure Workload Identity not supported");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package net.snowflake.client.core.auth.wif;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class GcpIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(GcpIdentityAttestationCreator.class);

@Override
public WorkloadIdentityAttestation createAttestation() throws SFException {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "GCP Workload Identity not supported");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package net.snowflake.client.core.auth.wif;

import net.snowflake.client.core.SFException;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.log.SFLogger;
import net.snowflake.client.log.SFLoggerFactory;

@SnowflakeJdbcInternalApi
public class OidcIdentityAttestationCreator implements WorkloadIdentityAttestationCreator {

private static final SFLogger logger =
SFLoggerFactory.getLogger(OidcIdentityAttestationCreator.class);

@Override
public WorkloadIdentityAttestation createAttestation() throws SFException {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "OIDC Workload Identity not supported");
}
}
Loading