From 0928efffebd6154ea20a5a65e11ce3aa0c9ab769 Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Fri, 10 Jan 2025 15:24:15 +0900 Subject: [PATCH 1/3] feat: support private_key access. --- .../src/turu/snowflake/connection.py | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 7fb0e0ef..3b8b57bf 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -1,7 +1,10 @@ import os from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Type, Union, cast, overload +from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, Union, cast, overload +from typing_extensions import Never, Unpack, override + +import snowflake.connector import turu.core.connection import turu.core.cursor import turu.core.mock @@ -13,9 +16,6 @@ PanderaDataFrame, PyArrowTable, ) -from typing_extensions import Never, Unpack, override - -import snowflake.connector from .cursor import ( Cursor, @@ -25,6 +25,9 @@ Self, ) +if TYPE_CHECKING: + from asn1crypto.keys import RSAPrivateKey + class Connection(turu.core.connection.Connection): """ @@ -49,8 +52,25 @@ def connect( # type: ignore[override] schema: Optional[str] = None, warehouse: Optional[str] = None, role: Optional[str] = None, + private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None, + private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: + if isinstance(private_key, str): + private_key = private_key.encode("utf-8") + + if isinstance(private_key_passphrase, str): + private_key_passphrase = private_key_passphrase.encode("utf-8") + + if isinstance(private_key, bytes) and private_key_passphrase is not None: + import base64 + from cryptography.hazmat.primitives.serialization import load_der_private_key + + private_key = load_der_private_key( + data=base64.b64decode(private_key), + password=private_key_passphrase, + ) # type: ignore[assignment] + return cls( snowflake.connector.SnowflakeConnection( connection_name, @@ -62,6 +82,7 @@ def connect( # type: ignore[override] schema=schema, warehouse=warehouse, role=role, + private_key=private_key, **kwargs, ) ) @@ -80,6 +101,8 @@ def connect_from_env( # type: ignore[override] warehouse_envname: str = "SNOWFLAKE_WAREHOUSE", role_envname: str = "SNOWFLAKE_ROLE", authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR", + private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY", + private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", **kwargs: Any, ) -> Self: if ( @@ -100,6 +123,10 @@ def connect_from_env( # type: ignore[override] schema=kwargs.get("schema", os.environ.get(schema_envname)), warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)), role=kwargs.get("role", os.environ.get(role_envname)), + private_key=kwargs.get("private_key", os.environ.get(private_key_envname)), + private_key_passphrase=kwargs.get( + "private_key_passphrase", os.environ.get(private_key_passphrase_envname) + ), **kwargs, ) From 564b0f0e4f0fc5794c081b9d95aa17a77e16cb34 Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Fri, 10 Jan 2025 15:26:51 +0900 Subject: [PATCH 2/3] fix: ci. --- .../src/turu/snowflake/connection.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 3b8b57bf..77319fa4 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -1,6 +1,16 @@ import os from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple, Type, Union, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, +) from typing_extensions import Never, Unpack, override @@ -62,14 +72,16 @@ def connect( # type: ignore[override] if isinstance(private_key_passphrase, str): private_key_passphrase = private_key_passphrase.encode("utf-8") - if isinstance(private_key, bytes) and private_key_passphrase is not None: + if isinstance(private_key, bytes) and private_key_passphrase is not None: import base64 - from cryptography.hazmat.primitives.serialization import load_der_private_key + from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + ) private_key = load_der_private_key( data=base64.b64decode(private_key), password=private_key_passphrase, - ) # type: ignore[assignment] + ) # type: ignore[assignment] return cls( snowflake.connector.SnowflakeConnection( From 13d52856bcd655c636f966e48abc7f576418592c Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Fri, 10 Jan 2025 15:33:24 +0900 Subject: [PATCH 3/3] chore: change: pop. --- turu-snowflake/src/turu/snowflake/connection.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 77319fa4..35d5aee1 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -130,13 +130,13 @@ def connect_from_env( # type: ignore[override] connections_file_path, user=kwargs.pop("user", os.environ.get(user_envname)), password=kwargs.pop("password", os.environ.get(password_envname)), - account=kwargs.get("account", os.environ.get(account_envname)), - database=kwargs.get("database", os.environ.get(database_envname)), - schema=kwargs.get("schema", os.environ.get(schema_envname)), - warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)), - role=kwargs.get("role", os.environ.get(role_envname)), - private_key=kwargs.get("private_key", os.environ.get(private_key_envname)), - private_key_passphrase=kwargs.get( + account=kwargs.pop("account", os.environ.get(account_envname)), + database=kwargs.pop("database", os.environ.get(database_envname)), + schema=kwargs.pop("schema", os.environ.get(schema_envname)), + warehouse=kwargs.pop("warehouse", os.environ.get(warehouse_envname)), + role=kwargs.pop("role", os.environ.get(role_envname)), + private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)), + private_key_passphrase=kwargs.pop( "private_key_passphrase", os.environ.get(private_key_passphrase_envname) ), **kwargs,