Skip to content

Commit 7989c34

Browse files
authored
Merge pull request #276 from yassun7010/support_private_key_access
feat: support private_key access.
2 parents d1b19a9 + 0084256 commit 7989c34

File tree

1 file changed

+48
-9
lines changed

1 file changed

+48
-9
lines changed

turu-snowflake/src/turu/snowflake/connection.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Optional, Sequence, Tuple, Type, Union, cast, overload
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Optional,
7+
Sequence,
8+
Tuple,
9+
Type,
10+
Union,
11+
cast,
12+
overload,
13+
)
14+
15+
from typing_extensions import Never, Unpack, override
416

17+
import snowflake.connector
518
import turu.core.connection
619
import turu.core.cursor
720
import turu.core.mock
@@ -13,9 +26,6 @@
1326
PanderaDataFrame,
1427
PyArrowTable,
1528
)
16-
from typing_extensions import Never, Unpack, override
17-
18-
import snowflake.connector
1929

2030
from .cursor import (
2131
Cursor,
@@ -25,6 +35,9 @@
2535
Self,
2636
)
2737

38+
if TYPE_CHECKING:
39+
from asn1crypto.keys import RSAPrivateKey
40+
2841

2942
class Connection(turu.core.connection.Connection):
3043
"""
@@ -49,8 +62,27 @@ def connect( # type: ignore[override]
4962
schema: Optional[str] = None,
5063
warehouse: Optional[str] = None,
5164
role: Optional[str] = None,
65+
private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None,
66+
private_key_passphrase: Union[str, bytes, None] = None,
5267
**kwargs,
5368
) -> Self:
69+
if isinstance(private_key, str):
70+
private_key = private_key.encode("utf-8")
71+
72+
if isinstance(private_key_passphrase, str):
73+
private_key_passphrase = private_key_passphrase.encode("utf-8")
74+
75+
if isinstance(private_key, bytes) and private_key_passphrase is not None:
76+
import base64
77+
from cryptography.hazmat.primitives.serialization import (
78+
load_der_private_key,
79+
)
80+
81+
private_key = load_der_private_key(
82+
data=base64.b64decode(private_key),
83+
password=private_key_passphrase,
84+
) # type: ignore[assignment]
85+
5486
return cls(
5587
snowflake.connector.SnowflakeConnection(
5688
connection_name,
@@ -62,6 +94,7 @@ def connect( # type: ignore[override]
6294
schema=schema,
6395
warehouse=warehouse,
6496
role=role,
97+
private_key=private_key,
6598
**kwargs,
6699
)
67100
)
@@ -80,6 +113,8 @@ def connect_from_env( # type: ignore[override]
80113
warehouse_envname: str = "SNOWFLAKE_WAREHOUSE",
81114
role_envname: str = "SNOWFLAKE_ROLE",
82115
authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR",
116+
private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY",
117+
private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE",
83118
**kwargs: Any,
84119
) -> Self:
85120
if (
@@ -95,11 +130,15 @@ def connect_from_env( # type: ignore[override]
95130
connections_file_path,
96131
user=kwargs.pop("user", os.environ.get(user_envname)),
97132
password=kwargs.pop("password", os.environ.get(password_envname)),
98-
account=kwargs.get("account", os.environ.get(account_envname)),
99-
database=kwargs.get("database", os.environ.get(database_envname)),
100-
schema=kwargs.get("schema", os.environ.get(schema_envname)),
101-
warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)),
102-
role=kwargs.get("role", os.environ.get(role_envname)),
133+
account=kwargs.pop("account", os.environ.get(account_envname)),
134+
database=kwargs.pop("database", os.environ.get(database_envname)),
135+
schema=kwargs.pop("schema", os.environ.get(schema_envname)),
136+
warehouse=kwargs.pop("warehouse", os.environ.get(warehouse_envname)),
137+
role=kwargs.pop("role", os.environ.get(role_envname)),
138+
private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)),
139+
private_key_passphrase=kwargs.pop(
140+
"private_key_passphrase", os.environ.get(private_key_passphrase_envname)
141+
),
103142
**kwargs,
104143
)
105144

0 commit comments

Comments
 (0)