diff --git a/turu-snowflake/src/turu/snowflake/_key_pair.py b/turu-snowflake/src/turu/snowflake/_key_pair.py new file mode 100644 index 0000000..83d23ac --- /dev/null +++ b/turu-snowflake/src/turu/snowflake/_key_pair.py @@ -0,0 +1,56 @@ +import base64 +from typing import TYPE_CHECKING, Optional +from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + load_pem_private_key, + Encoding, + PrivateFormat, + NoEncryption, +) + +if TYPE_CHECKING: + from asn1crypto.keys import RSAPrivateKey + + +def load_private_key( + data: bytes, private_key_passphrase: Optional[bytes] = None +) -> "RSAPrivateKey": + if data.startswith(b"-----BEGIN "): + return _load_private_key_from_pem(data, private_key_passphrase) + else: + return _load_private_key_from_der(data, private_key_passphrase) + + +def _load_private_key_from_pem( + data: bytes, private_key_passphrase: Optional[bytes] = None +) -> "RSAPrivateKey": + """ + Load a private key from a PEM encoded byte string. + """ + p_key = load_pem_private_key(data, password=private_key_passphrase) + private_key = p_key.private_bytes( + encoding=Encoding.DER, + format=PrivateFormat.PKCS8, + encryption_algorithm=NoEncryption(), + ) + return load_der_private_key( + data=private_key, + password=None, + ) # type: ignore[assignment] + + +def _load_private_key_from_der( + data: bytes, private_key_passphrase: Optional[bytes] = None +) -> "RSAPrivateKey": + """ + Load a private key from a DER encoded byte string. + """ + try: + data = base64.b64decode(data) + except Exception: + pass + + return load_der_private_key( + data=data, + password=private_key_passphrase, + ) # type: ignore[assignment] diff --git a/turu-snowflake/src/turu/snowflake/async_connection.py b/turu-snowflake/src/turu/snowflake/async_connection.py index 7ab3b3b..cfab0e8 100644 --- a/turu-snowflake/src/turu/snowflake/async_connection.py +++ b/turu-snowflake/src/turu/snowflake/async_connection.py @@ -51,25 +51,23 @@ async def connect( # type: ignore[override] warehouse: Optional[str] = None, role: Optional[str] = None, private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None, + private_key_file: Union[str, Path, None] = None, private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: - if isinstance(private_key, str): + if isinstance(private_key_file, (str, Path)): + with open(private_key_file, "rb") as key: + private_key = key.read() + elif isinstance(private_key, str): private_key = private_key.encode("utf-8") - if isinstance(private_key_passphrase, str): - private_key_passphrase = private_key_passphrase.encode("utf-8") + if isinstance(private_key, bytes): + from turu.snowflake._key_pair import load_private_key - if isinstance(private_key, bytes) and private_key_passphrase is not None: - import base64 - from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, - ) + if isinstance(private_key_passphrase, str): + private_key_passphrase = private_key_passphrase.encode("utf-8") - private_key = load_der_private_key( - data=base64.b64decode(private_key), - password=private_key_passphrase, - ) # type: ignore[assignment] + private_key = load_private_key(private_key, private_key_passphrase) return cls( snowflake.connector.SnowflakeConnection( @@ -102,6 +100,7 @@ async def connect_from_env( # type: ignore[override] role_envname: str = "SNOWFLAKE_ROLE", authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR", private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY", + private_key_file_envname: str = "SNOWFLAKE_PRIVATE_KEY_FILE", private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", **kwargs, ) -> Self: @@ -124,6 +123,9 @@ async def connect_from_env( # type: ignore[override] warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)), role=kwargs.get("role", os.environ.get(role_envname)), private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)), + private_key_file=kwargs.pop( + "private_key_file", os.environ.get(private_key_file_envname) + ), private_key_passphrase=kwargs.pop( "private_key_passphrase", os.environ.get(private_key_passphrase_envname) ), diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 35d5aee..442690c 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -63,25 +63,23 @@ def connect( # type: ignore[override] warehouse: Optional[str] = None, role: Optional[str] = None, private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None, + private_key_file: Union[str, Path, None] = None, private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: - if isinstance(private_key, str): + if isinstance(private_key_file, (str, Path)): + with open(private_key_file, "rb") as key: + private_key = key.read() + elif isinstance(private_key, str): private_key = private_key.encode("utf-8") - if isinstance(private_key_passphrase, str): - private_key_passphrase = private_key_passphrase.encode("utf-8") + if isinstance(private_key, bytes): + from turu.snowflake._key_pair import load_private_key - if isinstance(private_key, bytes) and private_key_passphrase is not None: - import base64 - from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, - ) + if isinstance(private_key_passphrase, str): + private_key_passphrase = private_key_passphrase.encode("utf-8") - private_key = load_der_private_key( - data=base64.b64decode(private_key), - password=private_key_passphrase, - ) # type: ignore[assignment] + private_key = load_private_key(private_key, private_key_passphrase) return cls( snowflake.connector.SnowflakeConnection( @@ -114,6 +112,7 @@ def connect_from_env( # type: ignore[override] role_envname: str = "SNOWFLAKE_ROLE", authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR", private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY", + private_key_file_envname: str = "SNOWFLAKE_PRIVATE_KEY_FILE", private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", **kwargs: Any, ) -> Self: @@ -136,6 +135,9 @@ def connect_from_env( # type: ignore[override] warehouse=kwargs.pop("warehouse", os.environ.get(warehouse_envname)), role=kwargs.pop("role", os.environ.get(role_envname)), private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)), + private_key_file=kwargs.pop( + "private_key_file", os.environ.get(private_key_file_envname) + ), private_key_passphrase=kwargs.pop( "private_key_passphrase", os.environ.get(private_key_passphrase_envname) ),