From 6305266da2a7cbacff32324b84cd0c623b0e81af Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Fri, 28 Mar 2025 13:43:21 +0900 Subject: [PATCH 1/2] feat: add support for private key file in AsyncConnection and Connection classes --- .../src/turu/snowflake/async_connection.py | 43 ++++++++++++++----- .../src/turu/snowflake/connection.py | 43 ++++++++++++++----- 2 files changed, 66 insertions(+), 20 deletions(-) diff --git a/turu-snowflake/src/turu/snowflake/async_connection.py b/turu-snowflake/src/turu/snowflake/async_connection.py index 7ab3b3b..112b206 100644 --- a/turu-snowflake/src/turu/snowflake/async_connection.py +++ b/turu-snowflake/src/turu/snowflake/async_connection.py @@ -51,25 +51,44 @@ async def connect( # type: ignore[override] warehouse: Optional[str] = None, role: Optional[str] = None, private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None, + private_key_file: Union[str, Path, None] = None, private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: - if isinstance(private_key, str): - private_key = private_key.encode("utf-8") - if isinstance(private_key_passphrase, str): private_key_passphrase = private_key_passphrase.encode("utf-8") - if isinstance(private_key, bytes) and private_key_passphrase is not None: - import base64 + if isinstance(private_key_file, (str, Path)): from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, + load_der_private_key, load_pem_private_key, Encoding, PrivateFormat, NoEncryption + ) + with open(private_key_file, "rb") as key: + p_key = load_pem_private_key( + key.read(), password=private_key_passphrase + ) + private_key = p_key.private_bytes( + encoding=Encoding.DER, + format=PrivateFormat.PKCS8, + encryption_algorithm=NoEncryption(), ) - private_key = load_der_private_key( - data=base64.b64decode(private_key), - password=private_key_passphrase, - ) # type: ignore[assignment] + data=private_key, + password=None, + ) # type: ignore[assignment] + else: + if isinstance(private_key, str): + private_key = private_key.encode("utf-8") + + if isinstance(private_key, bytes): + import base64 + from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + ) + + private_key = load_der_private_key( + data=base64.b64decode(private_key), + password=private_key_passphrase, + ) # type: ignore[assignment] return cls( snowflake.connector.SnowflakeConnection( @@ -102,6 +121,7 @@ async def connect_from_env( # type: ignore[override] role_envname: str = "SNOWFLAKE_ROLE", authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR", private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY", + private_key_file_envname: str = "SNOWFLAKE_PRIVATE_KEY_FILE", private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", **kwargs, ) -> Self: @@ -124,6 +144,9 @@ async def connect_from_env( # type: ignore[override] warehouse=kwargs.get("warehouse", os.environ.get(warehouse_envname)), role=kwargs.get("role", os.environ.get(role_envname)), private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)), + private_key_file=kwargs.pop( + "private_key_file", os.environ.get(private_key_file_envname) + ), private_key_passphrase=kwargs.pop( "private_key_passphrase", os.environ.get(private_key_passphrase_envname) ), diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 35d5aee..9457f2b 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -63,25 +63,44 @@ def connect( # type: ignore[override] warehouse: Optional[str] = None, role: Optional[str] = None, private_key: "Union[str ,bytes ,RSAPrivateKey, None]" = None, + private_key_file: Union[str, Path, None] = None, private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: - if isinstance(private_key, str): - private_key = private_key.encode("utf-8") - if isinstance(private_key_passphrase, str): private_key_passphrase = private_key_passphrase.encode("utf-8") - if isinstance(private_key, bytes) and private_key_passphrase is not None: - import base64 + if isinstance(private_key_file, (str, Path)): from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, + load_der_private_key, load_pem_private_key, Encoding, PrivateFormat, NoEncryption + ) + with open(private_key_file, "rb") as key: + p_key = load_pem_private_key( + key.read(), password=private_key_passphrase + ) + private_key = p_key.private_bytes( + encoding=Encoding.DER, + format=PrivateFormat.PKCS8, + encryption_algorithm=NoEncryption(), ) - private_key = load_der_private_key( - data=base64.b64decode(private_key), - password=private_key_passphrase, - ) # type: ignore[assignment] + data=private_key, + password=None, + ) # type: ignore[assignment] + else: + if isinstance(private_key, str): + private_key = private_key.encode("utf-8") + + if isinstance(private_key, bytes): + import base64 + from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + ) + + private_key = load_der_private_key( + data=base64.b64decode(private_key), + password=private_key_passphrase, + ) # type: ignore[assignment] return cls( snowflake.connector.SnowflakeConnection( @@ -114,6 +133,7 @@ def connect_from_env( # type: ignore[override] role_envname: str = "SNOWFLAKE_ROLE", authenticator_envname: str = "SNOWFLAKE_AUTHENTICATOR", private_key_envname: str = "SNOWFLAKE_PRIVATE_KEY", + private_key_file_envname: str = "SNOWFLAKE_PRIVATE_KEY_FILE", private_key_passphrase_envname: str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE", **kwargs: Any, ) -> Self: @@ -136,6 +156,9 @@ def connect_from_env( # type: ignore[override] warehouse=kwargs.pop("warehouse", os.environ.get(warehouse_envname)), role=kwargs.pop("role", os.environ.get(role_envname)), private_key=kwargs.pop("private_key", os.environ.get(private_key_envname)), + private_key_file=kwargs.pop( + "private_key_file", os.environ.get(private_key_file_envname) + ), private_key_passphrase=kwargs.pop( "private_key_passphrase", os.environ.get(private_key_passphrase_envname) ), From 79f6d07d290a012a03d8e76d5ca9655a3446d260 Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Fri, 28 Mar 2025 14:07:15 +0900 Subject: [PATCH 2/2] feat: implement private key loading functionality in AsyncConnection and Connection classes --- .../src/turu/snowflake/_key_pair.py | 56 +++++++++++++++++++ .../src/turu/snowflake/async_connection.py | 43 ++++---------- .../src/turu/snowflake/connection.py | 43 ++++---------- 3 files changed, 78 insertions(+), 64 deletions(-) create mode 100644 turu-snowflake/src/turu/snowflake/_key_pair.py diff --git a/turu-snowflake/src/turu/snowflake/_key_pair.py b/turu-snowflake/src/turu/snowflake/_key_pair.py new file mode 100644 index 0000000..83d23ac --- /dev/null +++ b/turu-snowflake/src/turu/snowflake/_key_pair.py @@ -0,0 +1,56 @@ +import base64 +from typing import TYPE_CHECKING, Optional +from cryptography.hazmat.primitives.serialization import ( + load_der_private_key, + load_pem_private_key, + Encoding, + PrivateFormat, + NoEncryption, +) + +if TYPE_CHECKING: + from asn1crypto.keys import RSAPrivateKey + + +def load_private_key( + data: bytes, private_key_passphrase: Optional[bytes] = None +) -> "RSAPrivateKey": + if data.startswith(b"-----BEGIN "): + return _load_private_key_from_pem(data, private_key_passphrase) + else: + return _load_private_key_from_der(data, private_key_passphrase) + + +def _load_private_key_from_pem( + data: bytes, private_key_passphrase: Optional[bytes] = None +) -> "RSAPrivateKey": + """ + Load a private key from a PEM encoded byte string. + """ + p_key = load_pem_private_key(data, password=private_key_passphrase) + private_key = p_key.private_bytes( + encoding=Encoding.DER, + format=PrivateFormat.PKCS8, + encryption_algorithm=NoEncryption(), + ) + return load_der_private_key( + data=private_key, + password=None, + ) # type: ignore[assignment] + + +def _load_private_key_from_der( + data: bytes, private_key_passphrase: Optional[bytes] = None +) -> "RSAPrivateKey": + """ + Load a private key from a DER encoded byte string. + """ + try: + data = base64.b64decode(data) + except Exception: + pass + + return load_der_private_key( + data=data, + password=private_key_passphrase, + ) # type: ignore[assignment] diff --git a/turu-snowflake/src/turu/snowflake/async_connection.py b/turu-snowflake/src/turu/snowflake/async_connection.py index 112b206..cfab0e8 100644 --- a/turu-snowflake/src/turu/snowflake/async_connection.py +++ b/turu-snowflake/src/turu/snowflake/async_connection.py @@ -55,40 +55,19 @@ async def connect( # type: ignore[override] private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: - if isinstance(private_key_passphrase, str): - private_key_passphrase = private_key_passphrase.encode("utf-8") - if isinstance(private_key_file, (str, Path)): - from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, load_pem_private_key, Encoding, PrivateFormat, NoEncryption - ) with open(private_key_file, "rb") as key: - p_key = load_pem_private_key( - key.read(), password=private_key_passphrase - ) - private_key = p_key.private_bytes( - encoding=Encoding.DER, - format=PrivateFormat.PKCS8, - encryption_algorithm=NoEncryption(), - ) - private_key = load_der_private_key( - data=private_key, - password=None, - ) # type: ignore[assignment] - else: - if isinstance(private_key, str): - private_key = private_key.encode("utf-8") - - if isinstance(private_key, bytes): - import base64 - from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, - ) - - private_key = load_der_private_key( - data=base64.b64decode(private_key), - password=private_key_passphrase, - ) # type: ignore[assignment] + private_key = key.read() + elif isinstance(private_key, str): + private_key = private_key.encode("utf-8") + + if isinstance(private_key, bytes): + from turu.snowflake._key_pair import load_private_key + + if isinstance(private_key_passphrase, str): + private_key_passphrase = private_key_passphrase.encode("utf-8") + + private_key = load_private_key(private_key, private_key_passphrase) return cls( snowflake.connector.SnowflakeConnection( diff --git a/turu-snowflake/src/turu/snowflake/connection.py b/turu-snowflake/src/turu/snowflake/connection.py index 9457f2b..442690c 100644 --- a/turu-snowflake/src/turu/snowflake/connection.py +++ b/turu-snowflake/src/turu/snowflake/connection.py @@ -67,40 +67,19 @@ def connect( # type: ignore[override] private_key_passphrase: Union[str, bytes, None] = None, **kwargs, ) -> Self: - if isinstance(private_key_passphrase, str): - private_key_passphrase = private_key_passphrase.encode("utf-8") - if isinstance(private_key_file, (str, Path)): - from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, load_pem_private_key, Encoding, PrivateFormat, NoEncryption - ) with open(private_key_file, "rb") as key: - p_key = load_pem_private_key( - key.read(), password=private_key_passphrase - ) - private_key = p_key.private_bytes( - encoding=Encoding.DER, - format=PrivateFormat.PKCS8, - encryption_algorithm=NoEncryption(), - ) - private_key = load_der_private_key( - data=private_key, - password=None, - ) # type: ignore[assignment] - else: - if isinstance(private_key, str): - private_key = private_key.encode("utf-8") - - if isinstance(private_key, bytes): - import base64 - from cryptography.hazmat.primitives.serialization import ( - load_der_private_key, - ) - - private_key = load_der_private_key( - data=base64.b64decode(private_key), - password=private_key_passphrase, - ) # type: ignore[assignment] + private_key = key.read() + elif isinstance(private_key, str): + private_key = private_key.encode("utf-8") + + if isinstance(private_key, bytes): + from turu.snowflake._key_pair import load_private_key + + if isinstance(private_key_passphrase, str): + private_key_passphrase = private_key_passphrase.encode("utf-8") + + private_key = load_private_key(private_key, private_key_passphrase) return cls( snowflake.connector.SnowflakeConnection(