diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 7fb0e0ef..35d5aee1 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -1,7 +1,20 @@ import os from pathlib import Path -from typing import Any, Optional, Sequence, Tuple, Type, Union, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, +) + +from typing_extensions import Never, Unpack, override +import snowflake.connector import turu.core.connection import turu.core.cursor import turu.core.mock @@ -13,9 +26,6 @@ PanderaDataFrame, PyArrowTable, ) -from typing_extensions import Never, Unpack, override - -import snowflake.connector from .cursor import ( Cursor, @@ -25,6 +35,9 @@ Self, ) +if TYPE_CHECKING: + from asn1crypto.keys import RSAPrivateKey + class Connection(turu.core.connection.Connection): """ @@ -49,8 +62,27 @@ def connect( # type: ignore[override] schema: Optional[str] = None, warehouse: Optional[str] = None, role: Optional[str] = None, + private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None, + private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: + if isinstance(private_key, str): + private_key = private_key.encode("utf-8") + + if isinstance(private_key_passphrase, str): + private_key_passphrase = private_key_passphrase.encode("utf-8") + + if isinstance(private_key, bytes) and private_key_passphrase is not None: + import base64 + from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + ) + + private_key = load_der_private_key( + data=base64.b64decode(private_key), + password=private_key_passphrase, + ) # type: ignore[assignment] + return cls( snowflake.connector.SnowflakeConnection( connection_name, @@ -62,6 +94,7 @@ def connect( # type: ignore[override] schema=schema, warehouse=warehouse, role=role, + private_key=private_key, **kwargs, ) ) @@ -80,6 +113,8 @@ def connect_from_env( # type: ignore[override] warehouse_envname: str = "SNOWFLAKE_WAREHOUSE", role_envname: str = "SNOWFLAKE_ROLE", authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR", + private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY", + private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", **kwargs: Any, ) -> Self: if ( @@ -95,11 +130,15 @@ def connect_from_env( # type: ignore[override] connections_file_path, user=kwargs.pop("user", os.environ.get(user_envname)), password=kwargs.pop("password", os.environ.get(password_envname)), - account=kwargs.get("account", os.environ.get(account_envname)), - database=kwargs.get("database", os.environ.get(database_envname)), - schema=kwargs.get("schema", os.environ.get(schema_envname)), - warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)), - role=kwargs.get("role", os.environ.get(role_envname)), + account=kwargs.pop("account", os.environ.get(account_envname)), + database=kwargs.pop("database", os.environ.get(database_envname)), + schema=kwargs.pop("schema", os.environ.get(schema_envname)), + warehouse=kwargs.pop("warehouse", os.environ.get(warehouse_envname)), + role=kwargs.pop("role", os.environ.get(role_envname)), + private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)), + private_key_passphrase=kwargs.pop( + "private_key_passphrase", os.environ.get(private_key_passphrase_envname) + ), **kwargs, )