Skip to content

Commit 8866ade

Browse files
committed
feat: add Workload Identity Federation (WIF) authentication support
Adds support for Snowflake Workload Identity Federation as a new authentication method for the admin connection. Supported providers: - AWS (STS/SigV4, no token config needed) - GCP (metadata service, no token config needed) - AZURE (managed identity, optional entra_resource) - OIDC (requires workload_identity_token) New config fields: workload_identity_provider - required, one of AWS/GCP/AZURE/OIDC workload_identity_token - required for OIDC provider only workload_identity_entra_resource - optional, Azure only WIF is mutually exclusive with password and private_key auth. Adds 21 unit tests covering config construction, connection setup, and validation across all provider types.
1 parent 44b77de commit 8866ade

File tree

2 files changed

+390
-14
lines changed

2 files changed

+390
-14
lines changed

connection_producer.go

Lines changed: 116 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"fmt"
1313
"net/url"
1414
"regexp"
15+
"strings"
1516
"sync"
1617
"time"
1718

@@ -25,22 +26,38 @@ import (
2526
"github.com/snowflakedb/gosnowflake"
2627
)
2728

29+
const (
30+
wifProviderAWS = "AWS"
31+
wifProviderGCP = "GCP"
32+
wifProviderAzure = "AZURE"
33+
wifProviderOIDC = "OIDC"
34+
)
35+
36+
var validWIFProviders = []string{wifProviderAWS, wifProviderGCP, wifProviderAzure, wifProviderOIDC}
37+
2838
var (
2939
ErrInvalidSnowflakeURL = fmt.Errorf("invalid connection URL format, expect <account_name>.snowflakecomputing.com/<db_name>")
3040
ErrInvalidPrivateKey = fmt.Errorf("failed to read provided private_key")
41+
ErrWIFMutuallyExclusive = fmt.Errorf("workload_identity_provider cannot be combined with password or private_key authentication")
42+
ErrWIFTokenRequired = fmt.Errorf("workload_identity_token is required when using the OIDC workload identity provider")
43+
ErrWIFUsernameRequired = fmt.Errorf("username is required for workload identity federation")
44+
ErrWIFInvalidProvider = fmt.Errorf("workload_identity_provider must be one of: AWS, GCP, AZURE, OIDC")
3145
accountAndDBNameFromConnURLRegex = regexp.MustCompile(`^(.+)\.snowflakecomputing\.com/(.+)$`) // Expected format: <account_name>.snowflakecomputing.com/<db_name>
3246
)
3347

3448
type snowflakeConnectionProducer struct {
35-
ConnectionURL string `json:"connection_url"`
36-
MaxOpenConnections int `json:"max_open_connections"`
37-
MaxIdleConnections int `json:"max_idle_connections"`
38-
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime"`
39-
Username string `json:"username"`
40-
Password string `json:"password"`
41-
PrivateKey []byte `json:"private_key"`
42-
UsernameTemplate string `json:"username_template"`
43-
DisableEscaping bool `json:"disable_escaping"`
49+
ConnectionURL string `json:"connection_url"`
50+
MaxOpenConnections int `json:"max_open_connections"`
51+
MaxIdleConnections int `json:"max_idle_connections"`
52+
MaxConnectionLifetimeRaw interface{} `json:"max_connection_lifetime"`
53+
Username string `json:"username"`
54+
Password string `json:"password"`
55+
PrivateKey []byte `json:"private_key"`
56+
WorkloadIdentityProvider string `json:"workload_identity_provider"`
57+
WorkloadIdentityToken string `json:"workload_identity_token"`
58+
WorkloadIdentityEntraResource string `json:"workload_identity_entra_resource"`
59+
UsernameTemplate string `json:"username_template"`
60+
DisableEscaping bool `json:"disable_escaping"`
4461

4562
Initialized bool
4663
RawConfig map[string]any
@@ -53,8 +70,9 @@ type snowflakeConnectionProducer struct {
5370

5471
func (c *snowflakeConnectionProducer) secretValues() map[string]string {
5572
return map[string]string{
56-
c.Password: "[password]",
57-
string(c.PrivateKey): "[private_key]",
73+
c.Password: "[password]",
74+
string(c.PrivateKey): "[private_key]",
75+
c.WorkloadIdentityToken: "[workload_identity_token]",
5876
}
5977
}
6078

@@ -86,7 +104,11 @@ func (c *snowflakeConnectionProducer) Init(ctx context.Context, initConfig map[s
86104
return nil, fmt.Errorf("connection_url cannot be empty")
87105
}
88106

89-
if len(c.Password) > 0 {
107+
if len(c.WorkloadIdentityProvider) > 0 {
108+
if err := c.validateWIFConfig(); err != nil {
109+
return nil, err
110+
}
111+
} else if len(c.Password) > 0 {
90112
// Return an error here once Snowflake ends support for password auth.
91113
c.logger.Warn("[DEPRECATED] Single-factor password authentication is deprecated in Snowflake and will be removed by November 2025. " +
92114
"Key pair authentication will be required after this date.")
@@ -137,6 +159,36 @@ func (c *snowflakeConnectionProducer) Init(ctx context.Context, initConfig map[s
137159
return initConfig, nil
138160
}
139161

162+
// validateWIFConfig checks that the WIF configuration is valid and mutually
163+
// exclusive with other authentication methods.
164+
func (c *snowflakeConnectionProducer) validateWIFConfig() error {
165+
if len(c.Password) > 0 || len(c.PrivateKey) > 0 {
166+
return ErrWIFMutuallyExclusive
167+
}
168+
169+
if c.Username == "" {
170+
return ErrWIFUsernameRequired
171+
}
172+
173+
provider := strings.ToUpper(c.WorkloadIdentityProvider)
174+
valid := false
175+
for _, p := range validWIFProviders {
176+
if provider == p {
177+
valid = true
178+
break
179+
}
180+
}
181+
if !valid {
182+
return ErrWIFInvalidProvider
183+
}
184+
185+
if provider == wifProviderOIDC && c.WorkloadIdentityToken == "" {
186+
return ErrWIFTokenRequired
187+
}
188+
189+
return nil
190+
}
191+
140192
func (c *snowflakeConnectionProducer) Initialize(ctx context.Context, config map[string]any, verifyConnection bool) error {
141193
_, err := c.Init(ctx, config, verifyConnection)
142194
return err
@@ -162,12 +214,18 @@ func (c *snowflakeConnectionProducer) Connection(ctx context.Context) (interface
162214

163215
var db *sql.DB
164216
var err error
165-
if len(c.PrivateKey) > 0 {
217+
switch {
218+
case len(c.WorkloadIdentityProvider) > 0:
219+
db, err = openSnowflakeWIF(c.ConnectionURL, c.Username, c.WorkloadIdentityProvider, c.WorkloadIdentityToken, c.WorkloadIdentityEntraResource)
220+
if err != nil {
221+
return nil, fmt.Errorf("error opening Snowflake connection using workload identity federation: %w", err)
222+
}
223+
case len(c.PrivateKey) > 0:
166224
db, err = openSnowflake(c.ConnectionURL, c.Username, c.PrivateKey)
167225
if err != nil {
168226
return nil, fmt.Errorf("error opening Snowflake connection using key-pair auth: %w", err)
169227
}
170-
} else {
228+
default:
171229
db, err = sql.Open(snowflakeSQLTypeName, c.ConnectionURL)
172230
if err != nil {
173231
return nil, fmt.Errorf("error opening Snowflake connection using user-pass auth: %w", err)
@@ -202,6 +260,50 @@ func (c *snowflakeConnectionProducer) Close() error {
202260
return c.close()
203261
}
204262

263+
// openSnowflakeWIF opens a Snowflake connection using Workload Identity Federation.
264+
// provider must be one of: AWS, GCP, AZURE, OIDC.
265+
// token is required only for the OIDC provider.
266+
// entraResource is optional and only applies to Azure environments.
267+
func openSnowflakeWIF(connectionURL, username, provider, token, entraResource string) (*sql.DB, error) {
268+
cfg, err := getSnowflakeWIFConfig(connectionURL, username, provider, token, entraResource)
269+
if err != nil {
270+
return nil, fmt.Errorf("error constructing snowflake WIF config: %w", err)
271+
}
272+
connector := gosnowflake.NewConnector(gosnowflake.SnowflakeDriver{}, *cfg)
273+
return sql.OpenDB(connector), nil
274+
}
275+
276+
// getSnowflakeWIFConfig builds a gosnowflake.Config for Workload Identity Federation auth.
277+
func getSnowflakeWIFConfig(connectionURL, username, provider, token, entraResource string) (*gosnowflake.Config, error) {
278+
u, err := url.Parse(connectionURL)
279+
if err != nil {
280+
return nil, fmt.Errorf("error parsing Snowflake connection URL %q: %w", connectionURL, err)
281+
}
282+
283+
q := u.Query()
284+
q.Set("authenticator", gosnowflake.AuthTypeWorkloadIdentityFederation.String())
285+
u.RawQuery = q.Encode()
286+
287+
// construct dsn: "user:@<account>.snowflakecomputing.com/<db>?..."
288+
dsn := fmt.Sprintf("%s:%s@%s", username, "", u.String())
289+
cfg, err := gosnowflake.ParseDSN(dsn)
290+
if err != nil {
291+
return nil, fmt.Errorf("error parsing Snowflake DSN: %w", err)
292+
}
293+
294+
cfg.WorkloadIdentityProvider = strings.ToUpper(provider)
295+
296+
if token != "" {
297+
cfg.Token = token
298+
}
299+
300+
if entraResource != "" {
301+
cfg.WorkloadIdentityEntraResource = entraResource
302+
}
303+
304+
return cfg, nil
305+
}
306+
205307
// Open the DB connection to Snowflake or return an error.
206308
func openSnowflake(connectionURL, username string, providedPrivateKey []byte) (*sql.DB, error) {
207309
cfg, err := getSnowflakeConfig(connectionURL, username, providedPrivateKey)

0 commit comments

Comments
 (0)