1
1
from __future__ import annotations
2
2
3
+ from functools import cached_property
3
4
from operator import contains , eq
5
+ from pathlib import Path
4
6
from typing import TYPE_CHECKING , Any , Iterable , Sequence , cast
5
7
6
8
import snowflake .sqlalchemy .custom_types as sct
10
12
from singer_sdk import typing as th
11
13
from singer_sdk .connectors import SQLConnector
12
14
from singer_sdk .connectors .sql import FullyQualifiedName
15
+ from singer_sdk .exceptions import ConfigValidationError
13
16
from snowflake .sqlalchemy import URL
14
17
from snowflake .sqlalchemy .base import SnowflakeIdentifierPreparer
15
18
from snowflake .sqlalchemy .snowdialect import SnowflakeDialect
@@ -124,6 +127,46 @@ def _convert_type(sql_type): # noqa: ANN205, ANN001
124
127
125
128
return sql_type
126
129
130
+ def get_private_key (self ):
131
+ """Get private key from the right location."""
132
+ phrase = self .config .get ("private_key_passphrase" )
133
+ encoded_passphrase = phrase .encode () if phrase else None
134
+ if "private_key_path" in self .config :
135
+ with Path .open (self .config ["private_key_path" ], "rb" ) as key :
136
+ key_content = key .read ()
137
+ else :
138
+ key_content = self .config ["private_key" ].encode ()
139
+
140
+ p_key = serialization .load_pem_private_key (
141
+ key_content ,
142
+ password = encoded_passphrase ,
143
+ backend = default_backend (),
144
+ )
145
+
146
+ return p_key .private_bytes (
147
+ encoding = serialization .Encoding .DER ,
148
+ format = serialization .PrivateFormat .PKCS8 ,
149
+ encryption_algorithm = serialization .NoEncryption (),
150
+ )
151
+
152
+ @cached_property
153
+ def auth_method (self ):
154
+ """Validate & return the authentication method based on config."""
155
+ if self .config .get ("use_browser_authentication" ):
156
+ return "browser_authentication"
157
+
158
+ valid_auth_methods = {"private_key" , "private_key_path" , "password" }
159
+ config_auth_methods = [x for x in self .config if x in valid_auth_methods ]
160
+ if len (config_auth_methods ) == 1 :
161
+ return config_auth_methods [0 ]
162
+
163
+ msg = (
164
+ "Neither password nor private key was provided for "
165
+ "authentication. For password-less browser authentication via SSO, "
166
+ "set use_browser_authentication config option to True."
167
+ )
168
+ raise ConfigValidationError (msg )
169
+
127
170
def get_sqlalchemy_url (self , config : dict ) -> str :
128
171
"""Generates a SQLAlchemy URL for Snowflake.
129
172
@@ -136,17 +179,10 @@ def get_sqlalchemy_url(self, config: dict) -> str:
136
179
"database" : config ["database" ],
137
180
}
138
181
139
- if config . get ( "use_browser_authentication" ) :
182
+ if self . auth_method == "browser_authentication" :
140
183
params ["authenticator" ] = "externalbrowser"
141
- elif "password" in config :
184
+ elif self . auth_method == "password" :
142
185
params ["password" ] = config ["password" ]
143
- elif "private_key_path" not in config :
144
- msg = (
145
- "Neither password nor private_key_path was provided for "
146
- "authentication. For password-less browser authentication via SSO, "
147
- "set use_browser_authentication config option to True."
148
- )
149
- raise Exception (msg ) # noqa: TRY002
150
186
151
187
for option in ["warehouse" , "role" ]:
152
188
if config .get (option ):
@@ -173,20 +209,8 @@ def create_engine(self) -> Engine:
173
209
"QUOTED_IDENTIFIERS_IGNORE_CASE" : "TRUE" ,
174
210
},
175
211
}
176
- if "private_key_path" in self .config :
177
- with open (self .config ["private_key_path" ], "rb" ) as private_key_file : # noqa: PTH123
178
- private_key = serialization .load_pem_private_key (
179
- private_key_file .read (),
180
- password = self .config ["private_key_passphrase" ].encode ()
181
- if "private_key_passphrase" in self .config
182
- else None ,
183
- backend = default_backend (),
184
- )
185
- connect_args ["private_key" ] = private_key .private_bytes (
186
- encoding = serialization .Encoding .DER ,
187
- format = serialization .PrivateFormat .PKCS8 ,
188
- encryption_algorithm = serialization .NoEncryption (),
189
- )
212
+ if self .auth_method in ["private_key" , "private_key_path" ]:
213
+ connect_args ["private_key" ] = self .get_private_key ()
190
214
engine = sqlalchemy .create_engine (
191
215
self .sqlalchemy_url ,
192
216
connect_args = connect_args ,
0 commit comments