11import os
22from pathlib import Path
3- from typing import Any , Optional , Sequence , Tuple , Type , Union , cast , overload
3+ from typing import TYPE_CHECKING , Any , Optional , Sequence , Tuple , Type , Union , cast , overload
44
5+ from typing_extensions import Never , Unpack , override
6+
7+ import snowflake .connector
58import turu .core .connection
69import turu .core .cursor
710import turu .core .mock
1316 PanderaDataFrame ,
1417 PyArrowTable ,
1518)
16- from typing_extensions import Never , Unpack , override
17-
18- import snowflake .connector
1919
2020from .cursor import (
2121 Cursor ,
2525 Self ,
2626)
2727
28+ if TYPE_CHECKING :
29+ from asn1crypto .keys import RSAPrivateKey
30+
2831
2932class Connection (turu .core .connection .Connection ):
3033 """
@@ -49,8 +52,25 @@ def connect( # type: ignore[override]
4952 schema : Optional [str ] = None ,
5053 warehouse : Optional [str ] = None ,
5154 role : Optional [str ] = None ,
55+ private_key : "Union[str ,bytes ,RSAPrivateKey, None]" = None ,
56+ private_key_passphrase : Union [str , bytes , None ] = None ,
5257 ** kwargs ,
5358 ) -> Self :
59+ if isinstance (private_key , str ):
60+ private_key = private_key .encode ("utf-8" )
61+
62+ if isinstance (private_key_passphrase , str ):
63+ private_key_passphrase = private_key_passphrase .encode ("utf-8" )
64+
65+ if isinstance (private_key , bytes ) and private_key_passphrase is not None :
66+ import base64
67+ from cryptography .hazmat .primitives .serialization import load_der_private_key
68+
69+ private_key = load_der_private_key (
70+ data = base64 .b64decode (private_key ),
71+ password = private_key_passphrase ,
72+ ) # type: ignore[assignment]
73+
5474 return cls (
5575 snowflake .connector .SnowflakeConnection (
5676 connection_name ,
@@ -62,6 +82,7 @@ def connect( # type: ignore[override]
6282 schema = schema ,
6383 warehouse = warehouse ,
6484 role = role ,
85+ private_key = private_key ,
6586 ** kwargs ,
6687 )
6788 )
@@ -80,6 +101,8 @@ def connect_from_env( # type: ignore[override]
80101 warehouse_envname : str = "SNOWFLAKE_WAREHOUSE" ,
81102 role_envname : str = "SNOWFLAKE_ROLE" ,
82103 authenticator_envname : str = "SNOWFLAKE_AUTHENTICATOR" ,
104+ private_key_envname : str = "SNOWFLAKE_PRIVATE_KEY" ,
105+ private_key_passphrase_envname : str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE" ,
83106 ** kwargs : Any ,
84107 ) -> Self :
85108 if (
@@ -100,6 +123,10 @@ def connect_from_env( # type: ignore[override]
100123 schema = kwargs .get ("schema" , os .environ .get (schema_envname )),
101124 warehouse = kwargs .get ("warehouse" , os .environ .get (warehouse_envname )),
102125 role = kwargs .get ("role" , os .environ .get (role_envname )),
126+ private_key = kwargs .get ("private_key" , os .environ .get (private_key_envname )),
127+ private_key_passphrase = kwargs .get (
128+ "private_key_passphrase" , os .environ .get (private_key_passphrase_envname )
129+ ),
103130 ** kwargs ,
104131 )
105132
0 commit comments