diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 2eb5283..0db6f3d 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -75,7 +75,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, - static_conn_info: io.TextIOBase = None, + static_conn_info: Optional[io.TextIOBase] = None, ) -> None: self._cache: dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} # initialize default params @@ -147,6 +147,7 @@ async def connect( enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible + cache: Union[RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache] if instance_uri in self._cache: cache = self._cache[instance_uri] elif self._static_conn_info: diff --git a/google/cloud/alloydb/connector/connection_info.py b/google/cloud/alloydb/connector/connection_info.py index 65ca231..aed28ad 100644 --- a/google/cloud/alloydb/connector/connection_info.py +++ b/google/cloud/alloydb/connector/connection_info.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: import datetime - from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from google.cloud.alloydb.connector.enums import IPTypes @@ -41,7 +41,7 @@ class ConnectionInfo: cert_chain: list[str] ca_cert: str - key: rsa.RSAPrivateKey + key: PrivateKeyTypes ip_addrs: dict[str, Optional[str]] expiration: datetime.datetime context: Optional[ssl.SSLContext] = None diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index bcfe2e0..755e2b6 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -87,7 +87,7 @@ def __init__( ip_type: str | IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND, - static_conn_info: io.TextIOBase = None, + static_conn_info: Optional[io.TextIOBase] = None, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -176,6 +176,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible + cache: Union[RefreshAheadCache, LazyRefreshCache, StaticConnectionInfoCache] if instance_uri in self._cache: cache = self._cache[instance_uri] elif self._static_conn_info: diff --git a/google/cloud/alloydb/connector/static.py b/google/cloud/alloydb/connector/static.py index 287e29b..da90245 100644 --- a/google/cloud/alloydb/connector/static.py +++ b/google/cloud/alloydb/connector/static.py @@ -19,7 +19,6 @@ import json from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa from google.cloud.alloydb.connector.connection_info import ConnectionInfo @@ -70,9 +69,8 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None: } expiration = datetime.now(timezone.utc) + timedelta(hours=1) priv_key = static_info["privateKey"] - priv_key_bytes: rsa.RSAPrivateKey = serialization.load_pem_private_key( - priv_key.encode("UTF-8"), - password=None, + priv_key_bytes = serialization.load_pem_private_key( + priv_key.encode("UTF-8"), password=None ) self._info = ConnectionInfo( cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration diff --git a/google/cloud/alloydb/connector/utils.py b/google/cloud/alloydb/connector/utils.py index 908ce40..e4c9939 100644 --- a/google/cloud/alloydb/connector/utils.py +++ b/google/cloud/alloydb/connector/utils.py @@ -17,10 +17,11 @@ import aiofiles from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes async def _write_to_file( - dir_path: str, ca_cert: str, cert_chain: list[str], key: rsa.RSAPrivateKey + dir_path: str, ca_cert: str, cert_chain: list[str], key: PrivateKeyTypes ) -> tuple[str, str, str]: """ Helper function to write the server_ca, client certificate and