Skip to content

Commit 0928eff

Browse files
committed
feat: support private_key access.
1 parent 2c38fa0 commit 0928eff

File tree

1 file changed

+31
-4
lines changed

1 file changed

+31
-4
lines changed

turu-snowflake/src/turu/snowflake/connection.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Optional, Sequence, Tuple, Type, Union, cast, overload
3+
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, Union, cast, overload
44

5+
from typing_extensions import Never, Unpack, override
6+
7+
import snowflake.connector
58
import turu.core.connection
69
import turu.core.cursor
710
import turu.core.mock
@@ -13,9 +16,6 @@
1316
PanderaDataFrame,
1417
PyArrowTable,
1518
)
16-
from typing_extensions import Never, Unpack, override
17-
18-
import snowflake.connector
1919

2020
from .cursor import (
2121
Cursor,
@@ -25,6 +25,9 @@
2525
Self,
2626
)
2727

28+
if TYPE_CHECKING:
29+
from asn1crypto.keys import RSAPrivateKey
30+
2831

2932
class Connection(turu.core.connection.Connection):
3033
"""
@@ -49,8 +52,25 @@ def connect( # type: ignore[override]
4952
schema: Optional[str] = None,
5053
warehouse: Optional[str] = None,
5154
role: Optional[str] = None,
55+
private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None,
56+
private_key_passphrase: Union[str, bytes, None] = None,
5257
**kwargs,
5358
) -> Self:
59+
if isinstance(private_key, str):
60+
private_key = private_key.encode("utf-8")
61+
62+
if isinstance(private_key_passphrase, str):
63+
private_key_passphrase = private_key_passphrase.encode("utf-8")
64+
65+
if isinstance(private_key, bytes) and private_key_passphrase is not None:
66+
import base64
67+
from cryptography.hazmat.primitives.serialization import load_der_private_key
68+
69+
private_key = load_der_private_key(
70+
data=base64.b64decode(private_key),
71+
password=private_key_passphrase,
72+
) # type: ignore[assignment]
73+
5474
return cls(
5575
snowflake.connector.SnowflakeConnection(
5676
connection_name,
@@ -62,6 +82,7 @@ def connect( # type: ignore[override]
6282
schema=schema,
6383
warehouse=warehouse,
6484
role=role,
85+
private_key=private_key,
6586
**kwargs,
6687
)
6788
)
@@ -80,6 +101,8 @@ def connect_from_env( # type: ignore[override]
80101
warehouse_envname: str = "SNOWFLAKE_WAREHOUSE",
81102
role_envname: str = "SNOWFLAKE_ROLE",
82103
authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR",
104+
private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY",
105+
private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE",
83106
**kwargs: Any,
84107
) -> Self:
85108
if (
@@ -100,6 +123,10 @@ def connect_from_env( # type: ignore[override]
100123
schema=kwargs.get("schema", os.environ.get(schema_envname)),
101124
warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)),
102125
role=kwargs.get("role", os.environ.get(role_envname)),
126+
private_key=kwargs.get("private_key", os.environ.get(private_key_envname)),
127+
private_key_passphrase=kwargs.get(
128+
"private_key_passphrase", os.environ.get(private_key_passphrase_envname)
129+
),
103130
**kwargs,
104131
)
105132

0 commit comments

Comments
 (0)