Skip to content

Commit 6b25f3a

Browse files
authored
Merge pull request #286 from yassun7010/support_snowflake_private_key_async
feat: support private_key to turu.snowflake.AsyncConnection.
2 parents 787eeed + bb43cfd commit 6b25f3a

File tree

1 file changed

+42
-3
lines changed

1 file changed

+42
-3
lines changed

turu-snowflake/src/turu/snowflake/async_connection.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
import os
22
from pathlib import Path
3-
from typing import Any, Optional, Sequence, Tuple, Type, Union, cast, overload
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Optional,
7+
Sequence,
8+
Tuple,
9+
Type,
10+
Union,
11+
cast,
12+
overload,
13+
)
14+
15+
from typing_extensions import Never, Self, Unpack, override
416

17+
import snowflake.connector
518
import turu.core.async_connection
619
import turu.core.cursor
720
import turu.core.mock
@@ -13,9 +26,9 @@
1326
PanderaDataFrame,
1427
PyArrowTable,
1528
)
16-
from typing_extensions import Never, Self, Unpack, override
1729

18-
import snowflake.connector
30+
if TYPE_CHECKING:
31+
from asn1crypto.keys import RSAPrivateKey
1932

2033
from .async_cursor import AsyncCursor, ExecuteOptions
2134

@@ -37,8 +50,27 @@ async def connect( # type: ignore[override]
3750
schema: Optional[str] = None,
3851
warehouse: Optional[str] = None,
3952
role: Optional[str] = None,
53+
private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None,
54+
private_key_passphrase: Union[str, bytes, None] = None,
4055
**kwargs,
4156
) -> Self:
57+
if isinstance(private_key, str):
58+
private_key = private_key.encode("utf-8")
59+
60+
if isinstance(private_key_passphrase, str):
61+
private_key_passphrase = private_key_passphrase.encode("utf-8")
62+
63+
if isinstance(private_key, bytes) and private_key_passphrase is not None:
64+
import base64
65+
from cryptography.hazmat.primitives.serialization import (
66+
load_der_private_key,
67+
)
68+
69+
private_key = load_der_private_key(
70+
data=base64.b64decode(private_key),
71+
password=private_key_passphrase,
72+
) # type: ignore[assignment]
73+
4274
return cls(
4375
snowflake.connector.SnowflakeConnection(
4476
connection_name,
@@ -50,6 +82,7 @@ async def connect( # type: ignore[override]
5082
schema=schema,
5183
warehouse=warehouse,
5284
role=role,
85+
private_key=private_key,
5386
**kwargs,
5487
)
5588
)
@@ -68,6 +101,8 @@ async def connect_from_env( # type: ignore[override]
68101
warehouse_envname: str = "SNOWFLAKE_WAREHOUSE",
69102
role_envname: str = "SNOWFLAKE_ROLE",
70103
authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR",
104+
private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY",
105+
private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE",
71106
**kwargs,
72107
) -> Self:
73108
if (
@@ -88,6 +123,10 @@ async def connect_from_env( # type: ignore[override]
88123
schema=kwargs.get("schema", os.environ.get(schema_envname)),
89124
warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)),
90125
role=kwargs.get("role", os.environ.get(role_envname)),
126+
private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)),
127+
private_key_passphrase=kwargs.pop(
128+
"private_key_passphrase", os.environ.get(private_key_passphrase_envname)
129+
),
91130
**kwargs,
92131
)
93132

0 commit comments

Comments
 (0)