11import os
22from pathlib import Path
3- from typing import Any , Optional , Sequence , Tuple , Type , Union , cast , overload
3+ from typing import (
4+ TYPE_CHECKING ,
5+ Any ,
6+ Optional ,
7+ Sequence ,
8+ Tuple ,
9+ Type ,
10+ Union ,
11+ cast ,
12+ overload ,
13+ )
14+
15+ from typing_extensions import Never , Self , Unpack , override
416
17+ import snowflake .connector
518import turu .core .async_connection
619import turu .core .cursor
720import turu .core .mock
1326 PanderaDataFrame ,
1427 PyArrowTable ,
1528)
16- from typing_extensions import Never , Self , Unpack , override
1729
18- import snowflake .connector
30+ if TYPE_CHECKING :
31+ from asn1crypto .keys import RSAPrivateKey
1932
2033from .async_cursor import AsyncCursor , ExecuteOptions
2134
@@ -37,8 +50,27 @@ async def connect( # type: ignore[override]
3750 schema : Optional [str ] = None ,
3851 warehouse : Optional [str ] = None ,
3952 role : Optional [str ] = None ,
53+ private_key : "Union[str ,bytes ,RSAPrivateKey, None]" = None ,
54+ private_key_passphrase : Union [str , bytes , None ] = None ,
4055 ** kwargs ,
4156 ) -> Self :
57+ if isinstance (private_key , str ):
58+ private_key = private_key .encode ("utf-8" )
59+
60+ if isinstance (private_key_passphrase , str ):
61+ private_key_passphrase = private_key_passphrase .encode ("utf-8" )
62+
63+ if isinstance (private_key , bytes ) and private_key_passphrase is not None :
64+ import base64
65+ from cryptography .hazmat .primitives .serialization import (
66+ load_der_private_key ,
67+ )
68+
69+ private_key = load_der_private_key (
70+ data = base64 .b64decode (private_key ),
71+ password = private_key_passphrase ,
72+ ) # type: ignore[assignment]
73+
4274 return cls (
4375 snowflake .connector .SnowflakeConnection (
4476 connection_name ,
@@ -50,6 +82,7 @@ async def connect( # type: ignore[override]
5082 schema = schema ,
5183 warehouse = warehouse ,
5284 role = role ,
85+ private_key = private_key ,
5386 ** kwargs ,
5487 )
5588 )
@@ -68,6 +101,8 @@ async def connect_from_env( # type: ignore[override]
68101 warehouse_envname : str = "SNOWFLAKE_WAREHOUSE" ,
69102 role_envname : str = "SNOWFLAKE_ROLE" ,
70103 authenticator_envname : str = "SNOWFLAKE_AUTHENTICATOR" ,
104+ private_key_envname : str = "SNOWFLAKE_PRIVATE_KEY" ,
105+ private_key_passphrase_envname : str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE" ,
71106 ** kwargs ,
72107 ) -> Self :
73108 if (
@@ -88,6 +123,10 @@ async def connect_from_env( # type: ignore[override]
88123 schema = kwargs .get ("schema" , os .environ .get (schema_envname )),
89124 warehouse = kwargs .get ("warehouse" , os .environ .get (warehouse_envname )),
90125 role = kwargs .get ("role" , os .environ .get (role_envname )),
126+ private_key = kwargs .pop ("private_key" , os .environ .get (private_key_envname )),
127+ private_key_passphrase = kwargs .pop (
128+ "private_key_passphrase" , os .environ .get (private_key_passphrase_envname )
129+ ),
91130 ** kwargs ,
92131 )
93132
0 commit comments