20
20
from functools import partial
21
21
import logging
22
22
import os
23
+ import socket
23
24
from threading import Thread
24
25
from types import TracebackType
25
- from typing import Any , Optional , Union
26
+ from typing import Any , Callable , Optional , Union
26
27
27
28
import google .auth
28
29
from google .auth .credentials import Credentials
35
36
from google .cloud .sql .connector .enums import RefreshStrategy
36
37
from google .cloud .sql .connector .instance import RefreshAheadCache
37
38
from google .cloud .sql .connector .lazy import LazyRefreshCache
39
+ from google .cloud .sql .connector .monitored_cache import MonitoredCache
38
40
import google .cloud .sql .connector .pg8000 as pg8000
39
41
import google .cloud .sql .connector .pymysql as pymysql
40
42
import google .cloud .sql .connector .pytds as pytds
46
48
logger = logging .getLogger (name = __name__ )
47
49
48
50
ASYNC_DRIVERS = ["asyncpg" ]
51
+ SERVER_PROXY_PORT = 3307
49
52
_DEFAULT_SCHEME = "https://"
50
53
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
51
54
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -67,6 +70,7 @@ def __init__(
67
70
universe_domain : Optional [str ] = None ,
68
71
refresh_strategy : str | RefreshStrategy = RefreshStrategy .BACKGROUND ,
69
72
resolver : type [DefaultResolver ] | type [DnsResolver ] = DefaultResolver ,
73
+ failover_period : int = 30 ,
70
74
) -> None :
71
75
"""Initializes a Connector instance.
72
76
@@ -114,6 +118,11 @@ def __init__(
114
118
name. To resolve a DNS record to an instance connection name, use
115
119
DnsResolver.
116
120
Default: DefaultResolver
121
+
122
+ failover_period (int): The time interval in seconds between each
123
+ attempt to check if a failover has occured for a given instance.
124
+ Must be used with `resolver=DnsResolver` to have any effect.
125
+ Default: 30
117
126
"""
118
127
# if refresh_strategy is str, convert to RefreshStrategy enum
119
128
if isinstance (refresh_strategy , str ):
@@ -143,9 +152,7 @@ def __init__(
143
152
)
144
153
# initialize dict to store caches, key is a tuple consisting of instance
145
154
# connection name string and enable_iam_auth boolean flag
146
- self ._cache : dict [
147
- tuple [str , bool ], Union [RefreshAheadCache , LazyRefreshCache ]
148
- ] = {}
155
+ self ._cache : dict [tuple [str , bool ], MonitoredCache ] = {}
149
156
self ._client : Optional [CloudSQLClient ] = None
150
157
151
158
# initialize credentials
@@ -167,6 +174,7 @@ def __init__(
167
174
self ._enable_iam_auth = enable_iam_auth
168
175
self ._user_agent = user_agent
169
176
self ._resolver = resolver ()
177
+ self ._failover_period = failover_period
170
178
# if ip_type is str, convert to IPTypes enum
171
179
if isinstance (ip_type , str ):
172
180
ip_type = IPTypes ._from_str (ip_type )
@@ -285,15 +293,19 @@ async def connect_async(
285
293
driver = driver ,
286
294
)
287
295
enable_iam_auth = kwargs .pop ("enable_iam_auth" , self ._enable_iam_auth )
288
- if (instance_connection_string , enable_iam_auth ) in self ._cache :
289
- cache = self ._cache [(instance_connection_string , enable_iam_auth )]
296
+
297
+ conn_name = await self ._resolver .resolve (instance_connection_string )
298
+ # Cache entry must exist and not be closed
299
+ if (str (conn_name ), enable_iam_auth ) in self ._cache and not self ._cache [
300
+ (str (conn_name ), enable_iam_auth )
301
+ ].closed :
302
+ monitored_cache = self ._cache [(str (conn_name ), enable_iam_auth )]
290
303
else :
291
- conn_name = await self ._resolver .resolve (instance_connection_string )
292
304
if self ._refresh_strategy == RefreshStrategy .LAZY :
293
305
logger .debug (
294
306
f"['{ conn_name } ']: Refresh strategy is set to lazy refresh"
295
307
)
296
- cache = LazyRefreshCache (
308
+ cache : Union [ LazyRefreshCache , RefreshAheadCache ] = LazyRefreshCache (
297
309
conn_name ,
298
310
self ._client ,
299
311
self ._keys ,
@@ -309,8 +321,14 @@ async def connect_async(
309
321
self ._keys ,
310
322
enable_iam_auth ,
311
323
)
324
+ # wrap cache as a MonitoredCache
325
+ monitored_cache = MonitoredCache (
326
+ cache ,
327
+ self ._failover_period ,
328
+ self ._resolver ,
329
+ )
312
330
logger .debug (f"['{ conn_name } ']: Connection info added to cache" )
313
- self ._cache [(instance_connection_string , enable_iam_auth )] = cache
331
+ self ._cache [(str ( conn_name ) , enable_iam_auth )] = monitored_cache
314
332
315
333
connect_func = {
316
334
"pymysql" : pymysql .connect ,
@@ -321,7 +339,7 @@ async def connect_async(
321
339
322
340
# only accept supported database drivers
323
341
try :
324
- connector = connect_func [driver ]
342
+ connector : Callable = connect_func [driver ] # type: ignore
325
343
except KeyError :
326
344
raise KeyError (f"Driver '{ driver } ' is not supported." )
327
345
@@ -339,14 +357,14 @@ async def connect_async(
339
357
340
358
# attempt to get connection info for Cloud SQL instance
341
359
try :
342
- conn_info = await cache .connect_info ()
360
+ conn_info = await monitored_cache .connect_info ()
343
361
# validate driver matches intended database engine
344
362
DriverMapping .validate_engine (driver , conn_info .database_version )
345
363
ip_address = conn_info .get_preferred_ip (ip_type )
346
364
except Exception :
347
365
# with an error from Cloud SQL Admin API call or IP type, invalidate
348
366
# the cache and re-raise the error
349
- await self ._remove_cached (instance_connection_string , enable_iam_auth )
367
+ await self ._remove_cached (str ( conn_name ) , enable_iam_auth )
350
368
raise
351
369
logger .debug (f"['{ conn_info .conn_name } ']: Connecting to { ip_address } :3307" )
352
370
# format `user` param for automatic IAM database authn
@@ -367,18 +385,28 @@ async def connect_async(
367
385
await conn_info .create_ssl_context (enable_iam_auth ),
368
386
** kwargs ,
369
387
)
370
- # synchronous drivers are blocking and run using executor
388
+ # Create socket with SSLContext for sync drivers
389
+ ctx = await conn_info .create_ssl_context (enable_iam_auth )
390
+ sock = ctx .wrap_socket (
391
+ socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
392
+ server_hostname = ip_address ,
393
+ )
394
+ # If this connection was opened using a domain name, then store it
395
+ # for later in case we need to forcibly close it on failover.
396
+ if conn_info .conn_name .domain_name :
397
+ monitored_cache .sockets .append (sock )
398
+ # Synchronous drivers are blocking and run using executor
371
399
connect_partial = partial (
372
400
connector ,
373
401
ip_address ,
374
- await conn_info . create_ssl_context ( enable_iam_auth ) ,
402
+ sock ,
375
403
** kwargs ,
376
404
)
377
405
return await self ._loop .run_in_executor (None , connect_partial )
378
406
379
407
except Exception :
380
408
# with any exception, we attempt a force refresh, then throw the error
381
- await cache .force_refresh ()
409
+ await monitored_cache .force_refresh ()
382
410
raise
383
411
384
412
async def _remove_cached (
@@ -456,6 +484,7 @@ async def create_async_connector(
456
484
universe_domain : Optional [str ] = None ,
457
485
refresh_strategy : str | RefreshStrategy = RefreshStrategy .BACKGROUND ,
458
486
resolver : type [DefaultResolver ] | type [DnsResolver ] = DefaultResolver ,
487
+ failover_period : int = 30 ,
459
488
) -> Connector :
460
489
"""Helper function to create Connector object for asyncio connections.
461
490
@@ -507,6 +536,11 @@ async def create_async_connector(
507
536
DnsResolver.
508
537
Default: DefaultResolver
509
538
539
+ failover_period (int): The time interval in seconds between each
540
+ attempt to check if a failover has occured for a given instance.
541
+ Must be used with `resolver=DnsResolver` to have any effect.
542
+ Default: 30
543
+
510
544
Returns:
511
545
A Connector instance configured with running event loop.
512
546
"""
@@ -525,4 +559,5 @@ async def create_async_connector(
525
559
universe_domain = universe_domain ,
526
560
refresh_strategy = refresh_strategy ,
527
561
resolver = resolver ,
562
+ failover_period = failover_period ,
528
563
)
0 commit comments