11import os
22from pathlib import Path
3- from typing import Any , Optional , Sequence , Tuple , Type , Union , cast , overload
3+ from typing import (
4+ TYPE_CHECKING ,
5+ Any ,
6+ Optional ,
7+ Sequence ,
8+ Tuple ,
9+ Type ,
10+ Union ,
11+ cast ,
12+ overload ,
13+ )
14+
15+ from typing_extensions import Never , Unpack , override
416
17+ import snowflake .connector
518import turu .core .connection
619import turu .core .cursor
720import turu .core .mock
1326 PanderaDataFrame ,
1427 PyArrowTable ,
1528)
16- from typing_extensions import Never , Unpack , override
17-
18- import snowflake .connector
1929
2030from .cursor import (
2131 Cursor ,
2535 Self ,
2636)
2737
38+ if TYPE_CHECKING :
39+ from asn1crypto .keys import RSAPrivateKey
40+
2841
2942class Connection (turu .core .connection .Connection ):
3043 """
@@ -49,8 +62,27 @@ def connect( # type: ignore[override]
4962 schema : Optional [str ] = None ,
5063 warehouse : Optional [str ] = None ,
5164 role : Optional [str ] = None ,
65+ private_key : "Union[str ,bytes ,RSAPrivateKey, None]" = None ,
66+ private_key_passphrase : Union [str , bytes , None ] = None ,
5267 ** kwargs ,
5368 ) -> Self :
69+ if isinstance (private_key , str ):
70+ private_key = private_key .encode ("utf-8" )
71+
72+ if isinstance (private_key_passphrase , str ):
73+ private_key_passphrase = private_key_passphrase .encode ("utf-8" )
74+
75+ if isinstance (private_key , bytes ) and private_key_passphrase is not None :
76+ import base64
77+ from cryptography .hazmat .primitives .serialization import (
78+ load_der_private_key ,
79+ )
80+
81+ private_key = load_der_private_key (
82+ data = base64 .b64decode (private_key ),
83+ password = private_key_passphrase ,
84+ ) # type: ignore[assignment]
85+
5486 return cls (
5587 snowflake .connector .SnowflakeConnection (
5688 connection_name ,
@@ -62,6 +94,7 @@ def connect( # type: ignore[override]
6294 schema = schema ,
6395 warehouse = warehouse ,
6496 role = role ,
97+ private_key = private_key ,
6598 ** kwargs ,
6699 )
67100 )
@@ -80,6 +113,8 @@ def connect_from_env( # type: ignore[override]
80113 warehouse_envname : str = "SNOWFLAKE_WAREHOUSE" ,
81114 role_envname : str = "SNOWFLAKE_ROLE" ,
82115 authenticator_envname : str = "SNOWFLAKE_AUTHENTICATOR" ,
116+ private_key_envname : str = "SNOWFLAKE_PRIVATE_KEY" ,
117+ private_key_passphrase_envname : str = "SNOWFLAKE_PRIVATE_KEY_PASSPHRASE" ,
83118 ** kwargs : Any ,
84119 ) -> Self :
85120 if (
@@ -95,11 +130,15 @@ def connect_from_env( # type: ignore[override]
95130 connections_file_path ,
96131 user = kwargs .pop ("user" , os .environ .get (user_envname )),
97132 password = kwargs .pop ("password" , os .environ .get (password_envname )),
98- account = kwargs .get ("account" , os .environ .get (account_envname )),
99- database = kwargs .get ("database" , os .environ .get (database_envname )),
100- schema = kwargs .get ("schema" , os .environ .get (schema_envname )),
101- warehouse = kwargs .get ("warehouse" , os .environ .get (warehouse_envname )),
102- role = kwargs .get ("role" , os .environ .get (role_envname )),
133+ account = kwargs .pop ("account" , os .environ .get (account_envname )),
134+ database = kwargs .pop ("database" , os .environ .get (database_envname )),
135+ schema = kwargs .pop ("schema" , os .environ .get (schema_envname )),
136+ warehouse = kwargs .pop ("warehouse" , os .environ .get (warehouse_envname )),
137+ role = kwargs .pop ("role" , os .environ .get (role_envname )),
138+ private_key = kwargs .pop ("private_key" , os .environ .get (private_key_envname )),
139+ private_key_passphrase = kwargs .pop (
140+ "private_key_passphrase" , os .environ .get (private_key_passphrase_envname )
141+ ),
103142 ** kwargs ,
104143 )
105144
0 commit comments