Skip to content

Commit c45c13a

Browse files
Enable query parameters parsing in connection URL for keypair auth (#135)
Co-authored-by: Robert <17119716+robmonte@users.noreply.github.com>
1 parent 583752a commit c45c13a

File tree

4 files changed

+161
-81
lines changed

4 files changed

+161
-81
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
kind: BUG FIXES
2+
body: Enable query param parsing on connection URL
3+
time: 2025-09-11T12:14:01.697808-07:00
4+
custom:
5+
Issue: "135"

connection_producer.go

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -197,42 +197,45 @@ func (c *snowflakeConnectionProducer) Close() error {
197197

198198
// Open the DB connection to Snowflake or return an error.
199199
func openSnowflake(connectionURL, username string, providedPrivateKey []byte) (*sql.DB, error) {
200-
// Parse the connection_url for required fields. Should be of
201-
// the form <account_name>.snowflakecomputing.com/<db_name>
202-
accountName, dbName, err := parseSnowflakeFieldsFromURL(connectionURL)
200+
cfg, err := getSnowflakeConfig(connectionURL, username, providedPrivateKey)
203201
if err != nil {
204-
return nil, err
202+
return nil, fmt.Errorf("error constructing snowflake config: %w", err)
205203
}
204+
connector := gosnowflake.NewConnector(gosnowflake.SnowflakeDriver{}, *cfg)
206205

207-
privateKey, err := getPrivateKey(providedPrivateKey)
206+
return sql.OpenDB(connector), nil
207+
}
208+
209+
func getSnowflakeConfig(connectionURL, username string, providedPrivateKey []byte) (*gosnowflake.Config, error) {
210+
// <account_name>.snowflakecomputing.com/<db_name>?queryParameters...
211+
u, err := url.Parse(connectionURL)
208212
if err != nil {
209-
return nil, err
213+
return nil, fmt.Errorf("error parsing Snowflake connection URL %q: %w", connectionURL, err)
210214
}
211215

212-
snowflakeConfig := &gosnowflake.Config{
213-
Account: accountName,
214-
Database: dbName,
215-
User: username,
216-
Authenticator: gosnowflake.AuthTypeJwt,
217-
PrivateKey: privateKey,
216+
// add authenticator query param to URL to indicate JWT auth
217+
// https://pkg.go.dev/github.com/snowflakedb/gosnowflake#hdr-JWT_authentication
218+
q := u.Query()
219+
q.Set("authenticator", gosnowflake.AuthTypeJwt.String())
220+
//q.Set("privateKey", "true") // This is needed to avoid gosnowflake trying to read the private key from a file path
221+
u.RawQuery = q.Encode()
222+
223+
// construct dsn for gosnowflake
224+
// "user:""@<account_name>.snowflakecomputing.com/<db_name>?queryParameters...
225+
dsn := fmt.Sprintf("%s:%s@%s", username, "", u.String())
226+
cfg, err := gosnowflake.ParseDSN(dsn)
227+
if err != nil {
228+
return nil, fmt.Errorf("error parsing Snowflake DSN %s; err=%w", dsn, err)
218229
}
219-
connector := gosnowflake.NewConnector(gosnowflake.SnowflakeDriver{}, *snowflakeConfig)
220230

221-
return sql.OpenDB(connector), nil
222-
}
223-
224-
// parseSnowflakeFieldsFromURL uses a regex to extract account and DB
225-
// info from a connectionURL
226-
func parseSnowflakeFieldsFromURL(connectionURL string) (string, string, error) {
227-
if !accountAndDBNameFromConnURLRegex.MatchString(connectionURL) {
228-
return "", "", ErrInvalidSnowflakeURL
229-
}
230-
res := accountAndDBNameFromConnURLRegex.FindStringSubmatch(connectionURL)
231-
if len(res) != 3 {
232-
return "", "", ErrInvalidSnowflakeURL
231+
privateKey, err := getPrivateKey(providedPrivateKey)
232+
if err != nil {
233+
return nil, err
233234
}
234235

235-
return res[1], res[2], nil
236+
cfg.PrivateKey = privateKey
237+
238+
return cfg, nil
236239
}
237240

238241
// Open and decode the private key file

connection_producer_test.go

Lines changed: 75 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"crypto/rsa"
1010
"crypto/x509"
1111
"encoding/pem"
12+
"github.com/snowflakedb/gosnowflake"
1213
"testing"
1314

1415
"github.com/stretchr/testify/require"
@@ -41,65 +42,86 @@ func TestOpenSnowflake(t *testing.T) {
4142
require.NotNil(t, db.Stats())
4243
}
4344

44-
// TestParseSnowflakeFieldsFromURL validates that URL
45-
// parsing for keypair authentication works as expected
46-
func TestParseSnowflakeFieldsFromURL(t *testing.T) {
47-
tests := map[string]struct {
48-
connectionURL string
49-
wantAccount string
50-
wantDB string
51-
wantErr error
45+
func TestGetSnowflakeConfig(t *testing.T) {
46+
tt := map[string]struct {
47+
providedPrivateKey string
48+
username string
49+
connectionURL string
50+
expectedConfig *gosnowflake.Config
51+
expectedError string
5252
}{
53-
"valid URL": {
54-
connectionURL: "account.snowflakecomputing.com/db",
55-
wantAccount: "account",
56-
wantDB: "db",
57-
wantErr: nil,
58-
},
59-
"complex URL": {
60-
connectionURL: "dev.org_v2.1.5-us-eas2-1.snowflakecomputing.com/secret-db.name/withslash",
61-
wantAccount: "dev.org_v2.1.5-us-eas2-1",
62-
wantDB: "secret-db.name/withslash",
63-
wantErr: nil,
64-
},
65-
"invalid URL": {
66-
connectionURL: "invalid-url",
67-
wantAccount: "",
68-
wantDB: "",
69-
wantErr: ErrInvalidSnowflakeURL,
70-
},
71-
"missing account name": {
72-
connectionURL: ".snowflakecomputing.com/db",
73-
wantAccount: "",
74-
wantDB: "",
75-
wantErr: ErrInvalidSnowflakeURL,
76-
},
77-
"missing database name": {
78-
connectionURL: "account.snowflakecomputing.com/",
79-
wantAccount: "",
80-
wantDB: "",
81-
wantErr: ErrInvalidSnowflakeURL,
53+
// confirms that the connection URL format upon initial release is correctly parsed
54+
"key pair connection URL format without params": {
55+
providedPrivateKey: testPrivateKey,
56+
username: "testvaultuser",
57+
connectionURL: "testaccount.snowflakecomputing.com/testdb",
58+
expectedConfig: &gosnowflake.Config{
59+
Account: "testaccount",
60+
User: "testvaultuser",
61+
Database: "testdb",
62+
PrivateKey: func() *rsa.PrivateKey {
63+
key, _ := getPrivateKey([]byte(testPrivateKey))
64+
return key
65+
}(),
66+
Authenticator: gosnowflake.AuthTypeJwt,
67+
},
8268
},
83-
"missing domain": {
84-
connectionURL: "account..com/db",
85-
wantAccount: "",
86-
wantDB: "",
87-
wantErr: ErrInvalidSnowflakeURL,
69+
// confirms that query params in the connection URL are correctly parsed
70+
"key pair connection URL format with query params": {
71+
providedPrivateKey: testPrivateKey,
72+
username: "testvaultuser",
73+
connectionURL: "testaccount.snowflakecomputing.com/testdb?disableOCSPChecks=true&maxRetryCount=5",
74+
expectedConfig: &gosnowflake.Config{
75+
Account: "testaccount",
76+
User: "testvaultuser",
77+
Database: "testdb",
78+
PrivateKey: func() *rsa.PrivateKey {
79+
key, _ := getPrivateKey([]byte(testPrivateKey))
80+
return key
81+
}(),
82+
DisableOCSPChecks: true,
83+
MaxRetryCount: 5,
84+
Authenticator: gosnowflake.AuthTypeJwt,
85+
},
8886
},
89-
"escape dots": {
90-
connectionURL: "account.snowflakecomputingXcom/db",
91-
wantAccount: "",
92-
wantDB: "",
93-
wantErr: ErrInvalidSnowflakeURL,
87+
// confirms that DB is optional in the connection URL
88+
"key pair connection URL without DB": {
89+
providedPrivateKey: testPrivateKey,
90+
username: "testvaultuser",
91+
connectionURL: "testaccount.snowflakecomputing.com?disableOCSPChecks=true&maxRetryCount=5",
92+
expectedConfig: &gosnowflake.Config{
93+
Account: "testaccount",
94+
User: "testvaultuser",
95+
PrivateKey: func() *rsa.PrivateKey {
96+
key, _ := getPrivateKey([]byte(testPrivateKey))
97+
return key
98+
}(),
99+
DisableOCSPChecks: true,
100+
MaxRetryCount: 5,
101+
Authenticator: gosnowflake.AuthTypeJwt,
102+
},
94103
},
95104
}
96-
for name, tt := range tests {
97-
t.Run(name, func(t *testing.T) {
98-
user, db, err := parseSnowflakeFieldsFromURL(tt.connectionURL)
99105

100-
require.Equal(t, tt.wantAccount, user)
101-
require.Equal(t, tt.wantDB, db)
102-
require.Equal(t, tt.wantErr, err)
106+
for name, tc := range tt {
107+
t.Run(name, func(t *testing.T) {
108+
cfg, err := getSnowflakeConfig(tc.connectionURL, tc.username, []byte(tc.providedPrivateKey))
109+
if tc.expectedError != "" {
110+
require.Error(t, err)
111+
require.Contains(t, err.Error(), tc.expectedError)
112+
return
113+
}
114+
require.NoError(t, err)
115+
require.NotNil(t, cfg)
116+
// Compare all relevant fields for this test
117+
// this confirms that the config was correctly parsed from the provided inputs
118+
require.Equal(t, tc.expectedConfig.Account, cfg.Account)
119+
require.Equal(t, tc.expectedConfig.User, cfg.User)
120+
require.Equal(t, tc.expectedConfig.Database, cfg.Database)
121+
require.Equal(t, tc.expectedConfig.Authenticator, cfg.Authenticator)
122+
require.Equal(t, tc.expectedConfig.DisableOCSPChecks, cfg.DisableOCSPChecks)
123+
require.Equal(t, tc.expectedConfig.KeepSessionAlive, cfg.KeepSessionAlive)
124+
require.NotNil(t, cfg.PrivateKey)
103125
})
104126
}
105127
}

snowflake_test.go

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,53 @@ func TestSnowflakeSQL_Initialize(t *testing.T) {
9595
db := new()
9696
defer dbtesting.AssertClose(t, db)
9797

98-
connURL, rawBase64PrivateKey, user, err := getKeyPairAuthParameters()
98+
connURL, rawBase64PrivateKey, user, err := getKeyPairAuthParameters("")
99+
if err != nil {
100+
t.Fatalf("failed to retrieve connection URL: %s", err)
101+
}
102+
103+
// decode base64 encoded private key from environment
104+
privateKey, err := base64.StdEncoding.DecodeString(rawBase64PrivateKey)
105+
if err != nil {
106+
t.Fatalf("failed to decode private key: %s", err)
107+
}
108+
109+
expectedConfig := map[string]interface{}{
110+
"connection_url": connURL,
111+
"username": user,
112+
"private_key": privateKey,
113+
dbplugin.SupportedCredentialTypesKey: []interface{}{
114+
dbplugin.CredentialTypePassword.String(),
115+
dbplugin.CredentialTypeRSAPrivateKey.String(),
116+
},
117+
}
118+
req := dbplugin.InitializeRequest{
119+
Config: map[string]interface{}{
120+
"connection_url": connURL,
121+
"username": user,
122+
"private_key": privateKey,
123+
},
124+
VerifyConnection: true,
125+
}
126+
resp := dbtesting.AssertInitialize(t, db, req)
127+
if !reflect.DeepEqual(resp.Config, expectedConfig) {
128+
t.Fatalf("Actual: %#v\nExpected: %#v", resp.Config, expectedConfig)
129+
}
130+
131+
connProducer := db.snowflakeConnectionProducer
132+
if !connProducer.Initialized {
133+
t.Fatal("Database should be initialized")
134+
}
135+
})
136+
137+
// the environment variable SNOWFLAKE_PRIVATE_KEY in CI
138+
// is a base64 encoded string. As such, this test expects the
139+
// input for the variable to be base64 encoded
140+
t.Run("keypair auth with query params", func(t *testing.T) {
141+
db := new()
142+
defer dbtesting.AssertClose(t, db)
143+
144+
connURL, rawBase64PrivateKey, user, err := getKeyPairAuthParameters("disableOCSPChecks=true&maxRetryCount=5")
99145
if err != nil {
100146
t.Fatalf("failed to retrieve connection URL: %s", err)
101147
}
@@ -506,7 +552,7 @@ func dsnString() (string, error) {
506552
return dsnString, nil
507553
}
508554

509-
func getKeyPairAuthParameters() (connURL string, pKey string, user string, err error) {
555+
func getKeyPairAuthParameters(optionalQueryParams string) (connURL string, pKey string, user string, err error) {
510556
user = os.Getenv(envVarSnowflakeUser)
511557
pKey = os.Getenv(envVarSnowflakePrivateKey)
512558
account := os.Getenv(envVarSnowflakeAccount)
@@ -528,6 +574,10 @@ func getKeyPairAuthParameters() (connURL string, pKey string, user string, err e
528574

529575
connURL = fmt.Sprintf("%s.snowflakecomputing.com/%s", user, database)
530576

577+
if optionalQueryParams != "" {
578+
connURL = fmt.Sprintf("%s?%s", connURL, optionalQueryParams)
579+
}
580+
531581
return connURL, pKey, user, err
532582
}
533583

0 commit comments

Comments
 (0)