Skip to content

Commit 79f6d07

Browse files
committed
feat: implement private key loading functionality in AsyncConnection and Connection classes
1 parent 6305266 commit 79f6d07

File tree

3 files changed

+78
-64
lines changed

3 files changed

+78
-64
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import base64
2+
from typing import TYPE_CHECKING, Optional
3+
from cryptography.hazmat.primitives.serialization import (
4+
load_der_private_key,
5+
load_pem_private_key,
6+
Encoding,
7+
PrivateFormat,
8+
NoEncryption,
9+
)
10+
11+
if TYPE_CHECKING:
12+
from asn1crypto.keys import RSAPrivateKey
13+
14+
15+
def load_private_key(
16+
data: bytes, private_key_passphrase: Optional[bytes] = None
17+
) -> "RSAPrivateKey":
18+
if data.startswith(b"-----BEGIN "):
19+
return _load_private_key_from_pem(data, private_key_passphrase)
20+
else:
21+
return _load_private_key_from_der(data, private_key_passphrase)
22+
23+
24+
def _load_private_key_from_pem(
25+
data: bytes, private_key_passphrase: Optional[bytes] = None
26+
) -> "RSAPrivateKey":
27+
"""
28+
Load a private key from a PEM encoded byte string.
29+
"""
30+
p_key = load_pem_private_key(data, password=private_key_passphrase)
31+
private_key = p_key.private_bytes(
32+
encoding=Encoding.DER,
33+
format=PrivateFormat.PKCS8,
34+
encryption_algorithm=NoEncryption(),
35+
)
36+
return load_der_private_key(
37+
data=private_key,
38+
password=None,
39+
) # type: ignore[assignment]
40+
41+
42+
def _load_private_key_from_der(
43+
data: bytes, private_key_passphrase: Optional[bytes] = None
44+
) -> "RSAPrivateKey":
45+
"""
46+
Load a private key from a DER encoded byte string.
47+
"""
48+
try:
49+
data = base64.b64decode(data)
50+
except Exception:
51+
pass
52+
53+
return load_der_private_key(
54+
data=data,
55+
password=private_key_passphrase,
56+
) # type: ignore[assignment]

turu-snowflake/src/turu/snowflake/async_connection.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -55,40 +55,19 @@ async def connect( # type: ignore[override]
5555
private_key_passphrase: Union[str, bytes, None] = None,
5656
**kwargs,
5757
) -> Self:
58-
if isinstance(private_key_passphrase, str):
59-
private_key_passphrase = private_key_passphrase.encode("utf-8")
60-
6158
if isinstance(private_key_file, (str, Path)):
62-
from cryptography.hazmat.primitives.serialization import (
63-
load_der_private_key, load_pem_private_key, Encoding, PrivateFormat, NoEncryption
64-
)
6559
with open(private_key_file, "rb") as key:
66-
p_key = load_pem_private_key(
67-
key.read(), password=private_key_passphrase
68-
)
69-
private_key = p_key.private_bytes(
70-
encoding=Encoding.DER,
71-
format=PrivateFormat.PKCS8,
72-
encryption_algorithm=NoEncryption(),
73-
)
74-
private_key = load_der_private_key(
75-
data=private_key,
76-
password=None,
77-
) # type: ignore[assignment]
78-
else:
79-
if isinstance(private_key, str):
80-
private_key = private_key.encode("utf-8")
81-
82-
if isinstance(private_key, bytes):
83-
import base64
84-
from cryptography.hazmat.primitives.serialization import (
85-
load_der_private_key,
86-
)
87-
88-
private_key = load_der_private_key(
89-
data=base64.b64decode(private_key),
90-
password=private_key_passphrase,
91-
) # type: ignore[assignment]
60+
private_key = key.read()
61+
elif isinstance(private_key, str):
62+
private_key = private_key.encode("utf-8")
63+
64+
if isinstance(private_key, bytes):
65+
from turu.snowflake._key_pair import load_private_key
66+
67+
if isinstance(private_key_passphrase, str):
68+
private_key_passphrase = private_key_passphrase.encode("utf-8")
69+
70+
private_key = load_private_key(private_key, private_key_passphrase)
9271

9372
return cls(
9473
snowflake.connector.SnowflakeConnection(

turu-snowflake/src/turu/snowflake/connection.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -67,40 +67,19 @@ def connect( # type: ignore[override]
6767
private_key_passphrase: Union[str, bytes, None] = None,
6868
**kwargs,
6969
) -> Self:
70-
if isinstance(private_key_passphrase, str):
71-
private_key_passphrase = private_key_passphrase.encode("utf-8")
72-
7370
if isinstance(private_key_file, (str, Path)):
74-
from cryptography.hazmat.primitives.serialization import (
75-
load_der_private_key, load_pem_private_key, Encoding, PrivateFormat, NoEncryption
76-
)
7771
with open(private_key_file, "rb") as key:
78-
p_key = load_pem_private_key(
79-
key.read(), password=private_key_passphrase
80-
)
81-
private_key = p_key.private_bytes(
82-
encoding=Encoding.DER,
83-
format=PrivateFormat.PKCS8,
84-
encryption_algorithm=NoEncryption(),
85-
)
86-
private_key = load_der_private_key(
87-
data=private_key,
88-
password=None,
89-
) # type: ignore[assignment]
90-
else:
91-
if isinstance(private_key, str):
92-
private_key = private_key.encode("utf-8")
93-
94-
if isinstance(private_key, bytes):
95-
import base64
96-
from cryptography.hazmat.primitives.serialization import (
97-
load_der_private_key,
98-
)
99-
100-
private_key = load_der_private_key(
101-
data=base64.b64decode(private_key),
102-
password=private_key_passphrase,
103-
) # type: ignore[assignment]
72+
private_key = key.read()
73+
elif isinstance(private_key, str):
74+
private_key = private_key.encode("utf-8")
75+
76+
if isinstance(private_key, bytes):
77+
from turu.snowflake._key_pair import load_private_key
78+
79+
if isinstance(private_key_passphrase, str):
80+
private_key_passphrase = private_key_passphrase.encode("utf-8")
81+
82+
private_key = load_private_key(private_key, private_key_passphrase)
10483

10584
return cls(
10685
snowflake.connector.SnowflakeConnection(

0 commit comments

Comments
 (0)