Skip to content

Commit 0063e1f

Browse files
authored
Merge pull request #293 from yassun7010/support_private_key_file
feat: add support for private key file in AsyncConnection and Connection classes
2 parents 9782eb5 + 5945865 commit 0063e1f

File tree

3 files changed

+84
-24
lines changed

3 files changed

+84
-24
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import base64
2+
from typing import TYPE_CHECKING, Optional
3+
from cryptography.hazmat.primitives.serialization import (
4+
load_der_private_key,
5+
load_pem_private_key,
6+
Encoding,
7+
PrivateFormat,
8+
NoEncryption,
9+
)
10+
11+
if TYPE_CHECKING:
12+
from asn1crypto.keys import RSAPrivateKey
13+
14+
15+
def load_private_key(
16+
data: bytes, private_key_passphrase: Optional[bytes] = None
17+
) -> "RSAPrivateKey":
18+
if data.startswith(b"-----BEGIN "):
19+
return _load_private_key_from_pem(data, private_key_passphrase)
20+
else:
21+
return _load_private_key_from_der(data, private_key_passphrase)
22+
23+
24+
def _load_private_key_from_pem(
25+
data: bytes, private_key_passphrase: Optional[bytes] = None
26+
) -> "RSAPrivateKey":
27+
"""
28+
Load a private key from a PEM encoded byte string.
29+
"""
30+
p_key = load_pem_private_key(data, password=private_key_passphrase)
31+
private_key = p_key.private_bytes(
32+
encoding=Encoding.DER,
33+
format=PrivateFormat.PKCS8,
34+
encryption_algorithm=NoEncryption(),
35+
)
36+
return load_der_private_key(
37+
data=private_key,
38+
password=None,
39+
) # type: ignore[assignment]
40+
41+
42+
def _load_private_key_from_der(
43+
data: bytes, private_key_passphrase: Optional[bytes] = None
44+
) -> "RSAPrivateKey":
45+
"""
46+
Load a private key from a DER encoded byte string.
47+
"""
48+
try:
49+
data = base64.b64decode(data)
50+
except Exception:
51+
pass
52+
53+
return load_der_private_key(
54+
data=data,
55+
password=private_key_passphrase,
56+
) # type: ignore[assignment]

turu-snowflake/src/turu/snowflake/async_connection.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,23 @@ async def connect( # type: ignore[override]
5151
warehouse: Optional[str] = None,
5252
role: Optional[str] = None,
5353
private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None,
54+
private_key_file: Union[str, Path, None] = None,
5455
private_key_passphrase: Union[str, bytes, None] = None,
5556
**kwargs,
5657
) -> Self:
57-
if isinstance(private_key, str):
58+
if isinstance(private_key_file, (str, Path)):
59+
with open(private_key_file, "rb") as key:
60+
private_key = key.read()
61+
elif isinstance(private_key, str):
5862
private_key = private_key.encode("utf-8")
5963

60-
if isinstance(private_key_passphrase, str):
61-
private_key_passphrase = private_key_passphrase.encode("utf-8")
64+
if isinstance(private_key, bytes):
65+
from turu.snowflake._key_pair import load_private_key
6266

63-
if isinstance(private_key, bytes) and private_key_passphrase is not None:
64-
import base64
65-
from cryptography.hazmat.primitives.serialization import (
66-
load_der_private_key,
67-
)
67+
if isinstance(private_key_passphrase, str):
68+
private_key_passphrase = private_key_passphrase.encode("utf-8")
6869

69-
private_key = load_der_private_key(
70-
data=base64.b64decode(private_key),
71-
password=private_key_passphrase,
72-
) # type: ignore[assignment]
70+
private_key = load_private_key(private_key, private_key_passphrase)
7371

7472
return cls(
7573
snowflake.connector.SnowflakeConnection(
@@ -102,6 +100,7 @@ async def connect_from_env( # type: ignore[override]
102100
role_envname: str = "SNOWFLAKE_ROLE",
103101
authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR",
104102
private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY",
103+
private_key_file_envname: str = "SNOWFLAKE_PRIVATE_KEY_FILE",
105104
private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE",
106105
**kwargs,
107106
) -> Self:
@@ -124,6 +123,9 @@ async def connect_from_env( # type: ignore[override]
124123
warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)),
125124
role=kwargs.get("role", os.environ.get(role_envname)),
126125
private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)),
126+
private_key_file=kwargs.pop(
127+
"private_key_file", os.environ.get(private_key_file_envname)
128+
),
127129
private_key_passphrase=kwargs.pop(
128130
"private_key_passphrase", os.environ.get(private_key_passphrase_envname)
129131
),

turu-snowflake/src/turu/snowflake/connection.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,23 @@ def connect( # type: ignore[override]
6363
warehouse: Optional[str] = None,
6464
role: Optional[str] = None,
6565
private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None,
66+
private_key_file: Union[str, Path, None] = None,
6667
private_key_passphrase: Union[str, bytes, None] = None,
6768
**kwargs,
6869
) -> Self:
69-
if isinstance(private_key, str):
70+
if isinstance(private_key_file, (str, Path)):
71+
with open(private_key_file, "rb") as key:
72+
private_key = key.read()
73+
elif isinstance(private_key, str):
7074
private_key = private_key.encode("utf-8")
7175

72-
if isinstance(private_key_passphrase, str):
73-
private_key_passphrase = private_key_passphrase.encode("utf-8")
76+
if isinstance(private_key, bytes):
77+
from turu.snowflake._key_pair import load_private_key
7478

75-
if isinstance(private_key, bytes) and private_key_passphrase is not None:
76-
import base64
77-
from cryptography.hazmat.primitives.serialization import (
78-
load_der_private_key,
79-
)
79+
if isinstance(private_key_passphrase, str):
80+
private_key_passphrase = private_key_passphrase.encode("utf-8")
8081

81-
private_key = load_der_private_key(
82-
data=base64.b64decode(private_key),
83-
password=private_key_passphrase,
84-
) # type: ignore[assignment]
82+
private_key = load_private_key(private_key, private_key_passphrase)
8583

8684
return cls(
8785
snowflake.connector.SnowflakeConnection(
@@ -114,6 +112,7 @@ def connect_from_env( # type: ignore[override]
114112
role_envname: str = "SNOWFLAKE_ROLE",
115113
authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR",
116114
private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY",
115+
private_key_file_envname: str = "SNOWFLAKE_PRIVATE_KEY_FILE",
117116
private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE",
118117
**kwargs: Any,
119118
) -> Self:
@@ -136,6 +135,9 @@ def connect_from_env( # type: ignore[override]
136135
warehouse=kwargs.pop("warehouse", os.environ.get(warehouse_envname)),
137136
role=kwargs.pop("role", os.environ.get(role_envname)),
138137
private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)),
138+
private_key_file=kwargs.pop(
139+
"private_key_file", os.environ.get(private_key_file_envname)
140+
),
139141
private_key_passphrase=kwargs.pop(
140142
"private_key_passphrase", os.environ.get(private_key_passphrase_envname)
141143
),

0 commit comments

Comments
 (0)