Skip to content

Commit 946fb2b

Browse files
committed
add private_key config parameter
1 parent a7b8699 commit 946fb2b

File tree

2 files changed

+61
-24
lines changed

2 files changed

+61
-24
lines changed

Diff for: target_snowflake/connector.py

+47-23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
from functools import cached_property
34
from operator import contains, eq
5+
from pathlib import Path
46
from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast
57

68
import snowflake.sqlalchemy.custom_types as sct
@@ -10,6 +12,7 @@
1012
from singer_sdk import typing as th
1113
from singer_sdk.connectors import SQLConnector
1214
from singer_sdk.connectors.sql import FullyQualifiedName
15+
from singer_sdk.exceptions import ConfigValidationError
1316
from snowflake.sqlalchemy import URL
1417
from snowflake.sqlalchemy.base import SnowflakeIdentifierPreparer
1518
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect
@@ -124,6 +127,46 @@ def _convert_type(sql_type): # noqa: ANN205, ANN001
124127

125128
return sql_type
126129

130+
def get_private_key(self):
131+
"""Get private key from the right location."""
132+
phrase = self.config.get("private_key_passphrase")
133+
encoded_passphrase = phrase.encode() if phrase else None
134+
if "private_key_path" in self.config:
135+
with Path.open(self.config["private_key_path"], "rb") as key:
136+
key_content = key.read()
137+
else:
138+
key_content = self.config["private_key"].encode()
139+
140+
p_key = serialization.load_pem_private_key(
141+
key_content,
142+
password=encoded_passphrase,
143+
backend=default_backend(),
144+
)
145+
146+
return p_key.private_bytes(
147+
encoding=serialization.Encoding.DER,
148+
format=serialization.PrivateFormat.PKCS8,
149+
encryption_algorithm=serialization.NoEncryption(),
150+
)
151+
152+
@cached_property
153+
def auth_method(self):
154+
"""Validate & return the authentication method based on config."""
155+
if self.config.get("use_browser_authentication"):
156+
return "browser_authentication"
157+
158+
valid_auth_methods = {"private_key", "private_key_path", "password"}
159+
config_auth_methods = [x for x in self.config if x in valid_auth_methods]
160+
if len(config_auth_methods) == 1:
161+
return config_auth_methods[0]
162+
163+
msg = (
164+
"Neither password nor private key was provided for "
165+
"authentication. For password-less browser authentication via SSO, "
166+
"set use_browser_authentication config option to True."
167+
)
168+
raise ConfigValidationError(msg)
169+
127170
def get_sqlalchemy_url(self, config: dict) -> str:
128171
"""Generates a SQLAlchemy URL for Snowflake.
129172
@@ -136,17 +179,10 @@ def get_sqlalchemy_url(self, config: dict) -> str:
136179
"database": config["database"],
137180
}
138181

139-
if config.get("use_browser_authentication"):
182+
if self.auth_method == "browser_authentication":
140183
params["authenticator"] = "externalbrowser"
141-
elif "password" in config:
184+
elif self.auth_method == "password":
142185
params["password"] = config["password"]
143-
elif "private_key_path" not in config:
144-
msg = (
145-
"Neither password nor private_key_path was provided for "
146-
"authentication. For password-less browser authentication via SSO, "
147-
"set use_browser_authentication config option to True."
148-
)
149-
raise Exception(msg) # noqa: TRY002
150186

151187
for option in ["warehouse", "role"]:
152188
if config.get(option):
@@ -173,20 +209,8 @@ def create_engine(self) -> Engine:
173209
"QUOTED_IDENTIFIERS_IGNORE_CASE": "TRUE",
174210
},
175211
}
176-
if "private_key_path" in self.config:
177-
with open(self.config["private_key_path"], "rb") as private_key_file: # noqa: PTH123
178-
private_key = serialization.load_pem_private_key(
179-
private_key_file.read(),
180-
password=self.config["private_key_passphrase"].encode()
181-
if "private_key_passphrase" in self.config
182-
else None,
183-
backend=default_backend(),
184-
)
185-
connect_args["private_key"] = private_key.private_bytes(
186-
encoding=serialization.Encoding.DER,
187-
format=serialization.PrivateFormat.PKCS8,
188-
encryption_algorithm=serialization.NoEncryption(),
189-
)
212+
if self.auth_method in ["private_key", "private_key_path"]:
213+
connect_args["private_key"] = self.get_private_key()
190214
engine = sqlalchemy.create_engine(
191215
self.sqlalchemy_url,
192216
connect_args=connect_args,

Diff for: target_snowflake/target.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,24 @@ class TargetSnowflake(SQLTarget):
3030
required=False,
3131
description="The password for your Snowflake user.",
3232
),
33+
th.Property(
34+
"private_key",
35+
th.StringType,
36+
required=False,
37+
secret=True,
38+
description=(
39+
"The private key contents. For KeyPair authentication either "
40+
"private_key or private_key_path must be provided."
41+
),
42+
),
3343
th.Property(
3444
"private_key_path",
3545
th.StringType,
3646
required=False,
37-
description="Path to file containing private key.",
47+
description=(
48+
"Path to file containing private key. For KeyPair authentication either "
49+
"private_key or private_key_path must be provided."
50+
),
3851
),
3952
th.Property(
4053
"private_key_passphrase",

0 commit comments

Comments
 (0)