Skip to content

Commit 1405f56

Browse files
feat: reset connection when the DNS record changes (#1241)
If the connector is configured with a domain name, when that domain name record changes to resolve to a new instance, the connector should detect that change, close all connections to the old instance, and create connections to the new instance.
1 parent dee267f commit 1405f56

27 files changed

+604
-95
lines changed

README.md

+38
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,44 @@ with Connector(resolver=DnsResolver) as connector:
428428
# ... use SQLAlchemy engine normally
429429
```
430430

431+
### Automatic failover using DNS domain names
432+
433+
> [!NOTE]
434+
>
435+
> The `asyncpg` driver does not currently support automatic failover.
436+
437+
When the connector is configured using a domain name, the connector will
438+
periodically check if the DNS record for an instance changes. When the connector
439+
detects that the domain name refers to a different instance, the connector will
440+
close all open connections to the old instance. Subsequent connection attempts
441+
will be directed to the new instance.
442+
443+
For example, suppose an application is configured to connect using the
444+
domain name `prod-db.mycompany.example.com`. Initially the private DNS
445+
zone has a TXT record with the value `my-project:region:my-instance`. The
446+
application establishes connections to the `my-project:region:my-instance`
447+
Cloud SQL instance.
448+
449+
Then, to reconfigure the application to use a different database
450+
instance, change the value of the `prod-db.mycompany.example.com` DNS record
451+
from `my-project:region:my-instance` to `my-project:other-region:my-instance-2`.
452+
453+
The connector inside the application detects the change to this
454+
DNS record. Now, when the application connects to its database using the
455+
domain name `prod-db.mycompany.example.com`, it will connect to the
456+
`my-project:other-region:my-instance-2` Cloud SQL instance.
457+
458+
The connector will automatically close all existing connections to
459+
`my-project:region:my-instance`. This will force the connection pools to
460+
establish new connections. Also, it may cause database queries in progress
461+
to fail.
462+
463+
The connector will poll for changes to the DNS name every 30 seconds by default.
464+
You may configure the frequency of the DNS checks using the Connector's
465+
`failover_period` argument (e.g. `Connector(failover_period=60)`). When this is
466+
set to 0, the connector will disable polling and only check if the DNS record
467+
changed when it is creating a new connection.
468+
431469
### Using the Python Connector with Python Web Frameworks
432470

433471
The Python Connector can be used alongside popular Python web frameworks such

google/cloud/sql/connector/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2019 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");

google/cloud/sql/connector/connection_info.py

+22
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import abc
1718
from dataclasses import dataclass
1819
import logging
1920
import ssl
@@ -34,6 +35,27 @@
3435
logger = logging.getLogger(name=__name__)
3536

3637

38+
class ConnectionInfoCache(abc.ABC):
39+
"""Abstract class for Connector connection info caches."""
40+
41+
@abc.abstractmethod
42+
async def connect_info(self) -> ConnectionInfo:
43+
pass
44+
45+
@abc.abstractmethod
46+
async def force_refresh(self) -> None:
47+
pass
48+
49+
@abc.abstractmethod
50+
async def close(self) -> None:
51+
pass
52+
53+
@property
54+
@abc.abstractmethod
55+
def closed(self) -> bool:
56+
pass
57+
58+
3759
@dataclass
3860
class ConnectionInfo:
3961
"""Contains all necessary information to connect securely to the

google/cloud/sql/connector/connection_name.py

+4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ def __str__(self) -> str:
4242
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
4343
return f"{self.project}:{self.region}:{self.instance_name}"
4444

45+
def get_connection_string(self) -> str:
46+
"""Get the instance connection string for the Cloud SQL instance."""
47+
return f"{self.project}:{self.region}:{self.instance_name}"
48+
4549

4650
def _is_valid_domain(domain_name: str) -> bool:
4751
if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None:

google/cloud/sql/connector/connector.py

+50-15
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
from functools import partial
2121
import logging
2222
import os
23+
import socket
2324
from threading import Thread
2425
from types import TracebackType
25-
from typing import Any, Optional, Union
26+
from typing import Any, Callable, Optional, Union
2627

2728
import google.auth
2829
from google.auth.credentials import Credentials
@@ -35,6 +36,7 @@
3536
from google.cloud.sql.connector.enums import RefreshStrategy
3637
from google.cloud.sql.connector.instance import RefreshAheadCache
3738
from google.cloud.sql.connector.lazy import LazyRefreshCache
39+
from google.cloud.sql.connector.monitored_cache import MonitoredCache
3840
import google.cloud.sql.connector.pg8000 as pg8000
3941
import google.cloud.sql.connector.pymysql as pymysql
4042
import google.cloud.sql.connector.pytds as pytds
@@ -46,6 +48,7 @@
4648
logger = logging.getLogger(name=__name__)
4749

4850
ASYNC_DRIVERS = ["asyncpg"]
51+
SERVER_PROXY_PORT = 3307
4952
_DEFAULT_SCHEME = "https://"
5053
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
5154
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -67,6 +70,7 @@ def __init__(
6770
universe_domain: Optional[str] = None,
6871
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
6972
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
73+
failover_period: int = 30,
7074
) -> None:
7175
"""Initializes a Connector instance.
7276
@@ -114,6 +118,11 @@ def __init__(
114118
name. To resolve a DNS record to an instance connection name, use
115119
DnsResolver.
116120
Default: DefaultResolver
121+
122+
failover_period (int): The time interval in seconds between each
123+
attempt to check if a failover has occurred for a given instance.
124+
Must be used with `resolver=DnsResolver` to have any effect.
125+
Default: 30
117126
"""
118127
# if refresh_strategy is str, convert to RefreshStrategy enum
119128
if isinstance(refresh_strategy, str):
@@ -143,9 +152,7 @@ def __init__(
143152
)
144153
# initialize dict to store caches, key is a tuple consisting of instance
145154
# connection name string and enable_iam_auth boolean flag
146-
self._cache: dict[
147-
tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache]
148-
] = {}
155+
self._cache: dict[tuple[str, bool], MonitoredCache] = {}
149156
self._client: Optional[CloudSQLClient] = None
150157

151158
# initialize credentials
@@ -167,6 +174,7 @@ def __init__(
167174
self._enable_iam_auth = enable_iam_auth
168175
self._user_agent = user_agent
169176
self._resolver = resolver()
177+
self._failover_period = failover_period
170178
# if ip_type is str, convert to IPTypes enum
171179
if isinstance(ip_type, str):
172180
ip_type = IPTypes._from_str(ip_type)
@@ -285,15 +293,19 @@ async def connect_async(
285293
driver=driver,
286294
)
287295
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
288-
if (instance_connection_string, enable_iam_auth) in self._cache:
289-
cache = self._cache[(instance_connection_string, enable_iam_auth)]
296+
297+
conn_name = await self._resolver.resolve(instance_connection_string)
298+
# Cache entry must exist and not be closed
299+
if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[
300+
(str(conn_name), enable_iam_auth)
301+
].closed:
302+
monitored_cache = self._cache[(str(conn_name), enable_iam_auth)]
290303
else:
291-
conn_name = await self._resolver.resolve(instance_connection_string)
292304
if self._refresh_strategy == RefreshStrategy.LAZY:
293305
logger.debug(
294306
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
295307
)
296-
cache = LazyRefreshCache(
308+
cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
297309
conn_name,
298310
self._client,
299311
self._keys,
@@ -309,8 +321,14 @@ async def connect_async(
309321
self._keys,
310322
enable_iam_auth,
311323
)
324+
# wrap cache as a MonitoredCache
325+
monitored_cache = MonitoredCache(
326+
cache,
327+
self._failover_period,
328+
self._resolver,
329+
)
312330
logger.debug(f"['{conn_name}']: Connection info added to cache")
313-
self._cache[(instance_connection_string, enable_iam_auth)] = cache
331+
self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache
314332

315333
connect_func = {
316334
"pymysql": pymysql.connect,
@@ -321,7 +339,7 @@ async def connect_async(
321339

322340
# only accept supported database drivers
323341
try:
324-
connector = connect_func[driver]
342+
connector: Callable = connect_func[driver] # type: ignore
325343
except KeyError:
326344
raise KeyError(f"Driver '{driver}' is not supported.")
327345

@@ -339,14 +357,14 @@ async def connect_async(
339357

340358
# attempt to get connection info for Cloud SQL instance
341359
try:
342-
conn_info = await cache.connect_info()
360+
conn_info = await monitored_cache.connect_info()
343361
# validate driver matches intended database engine
344362
DriverMapping.validate_engine(driver, conn_info.database_version)
345363
ip_address = conn_info.get_preferred_ip(ip_type)
346364
except Exception:
347365
# with an error from Cloud SQL Admin API call or IP type, invalidate
348366
# the cache and re-raise the error
349-
await self._remove_cached(instance_connection_string, enable_iam_auth)
367+
await self._remove_cached(str(conn_name), enable_iam_auth)
350368
raise
351369
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
352370
# format `user` param for automatic IAM database authn
@@ -367,18 +385,28 @@ async def connect_async(
367385
await conn_info.create_ssl_context(enable_iam_auth),
368386
**kwargs,
369387
)
370-
# synchronous drivers are blocking and run using executor
388+
# Create socket with SSLContext for sync drivers
389+
ctx = await conn_info.create_ssl_context(enable_iam_auth)
390+
sock = ctx.wrap_socket(
391+
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
392+
server_hostname=ip_address,
393+
)
394+
# If this connection was opened using a domain name, then store it
395+
# for later in case we need to forcibly close it on failover.
396+
if conn_info.conn_name.domain_name:
397+
monitored_cache.sockets.append(sock)
398+
# Synchronous drivers are blocking and run using executor
371399
connect_partial = partial(
372400
connector,
373401
ip_address,
374-
await conn_info.create_ssl_context(enable_iam_auth),
402+
sock,
375403
**kwargs,
376404
)
377405
return await self._loop.run_in_executor(None, connect_partial)
378406

379407
except Exception:
380408
# with any exception, we attempt a force refresh, then throw the error
381-
await cache.force_refresh()
409+
await monitored_cache.force_refresh()
382410
raise
383411

384412
async def _remove_cached(
@@ -456,6 +484,7 @@ async def create_async_connector(
456484
universe_domain: Optional[str] = None,
457485
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
458486
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
487+
failover_period: int = 30,
459488
) -> Connector:
460489
"""Helper function to create Connector object for asyncio connections.
461490
@@ -507,6 +536,11 @@ async def create_async_connector(
507536
DnsResolver.
508537
Default: DefaultResolver
509538
539+
failover_period (int): The time interval in seconds between each
540+
attempt to check if a failover has occurred for a given instance.
541+
Must be used with `resolver=DnsResolver` to have any effect.
542+
Default: 30
543+
510544
Returns:
511545
A Connector instance configured with running event loop.
512546
"""
@@ -525,4 +559,5 @@ async def create_async_connector(
525559
universe_domain=universe_domain,
526560
refresh_strategy=refresh_strategy,
527561
resolver=resolver,
562+
failover_period=failover_period,
528563
)

google/cloud/sql/connector/exceptions.py

+7
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,10 @@ class DnsResolutionError(Exception):
7777
Exception to be raised when an instance connection name can not be resolved
7878
from a DNS record.
7979
"""
80+
81+
82+
class CacheClosedError(Exception):
83+
"""
84+
Exception to be raised when a ConnectionInfoCache can not be accessed after
85+
it is closed.
86+
"""

google/cloud/sql/connector/instance.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from google.cloud.sql.connector.client import CloudSQLClient
2626
from google.cloud.sql.connector.connection_info import ConnectionInfo
27+
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
2728
from google.cloud.sql.connector.connection_name import ConnectionName
2829
from google.cloud.sql.connector.exceptions import RefreshNotValidError
2930
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
@@ -35,7 +36,7 @@
3536
APPLICATION_NAME = "cloud-sql-python-connector"
3637

3738

38-
class RefreshAheadCache:
39+
class RefreshAheadCache(ConnectionInfoCache):
3940
"""Cache that refreshes connection info in the background prior to expiration.
4041
4142
Background tasks are used to schedule refresh attempts to get a new
@@ -74,6 +75,15 @@ def __init__(
7475
self._refresh_in_progress = asyncio.locks.Event()
7576
self._current: asyncio.Task = self._schedule_refresh(0)
7677
self._next: asyncio.Task = self._current
78+
self._closed = False
79+
80+
@property
81+
def conn_name(self) -> ConnectionName:
82+
return self._conn_name
83+
84+
@property
85+
def closed(self) -> bool:
86+
return self._closed
7787

7888
async def force_refresh(self) -> None:
7989
"""
@@ -212,3 +222,4 @@ async def close(self) -> None:
212222
# gracefully wait for tasks to cancel
213223
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
214224
await asyncio.wait_for(tasks, timeout=2.0)
225+
self._closed = True

google/cloud/sql/connector/lazy.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@
2121

2222
from google.cloud.sql.connector.client import CloudSQLClient
2323
from google.cloud.sql.connector.connection_info import ConnectionInfo
24+
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
2425
from google.cloud.sql.connector.connection_name import ConnectionName
2526
from google.cloud.sql.connector.refresh_utils import _refresh_buffer
2627

2728
logger = logging.getLogger(name=__name__)
2829

2930

30-
class LazyRefreshCache:
31+
class LazyRefreshCache(ConnectionInfoCache):
3132
"""Cache that refreshes connection info when a caller requests a connection.
3233
3334
Only refreshes the cache when a new connection is requested and the current
@@ -62,6 +63,15 @@ def __init__(
6263
self._lock = asyncio.Lock()
6364
self._cached: Optional[ConnectionInfo] = None
6465
self._needs_refresh = False
66+
self._closed = False
67+
68+
@property
69+
def conn_name(self) -> ConnectionName:
70+
return self._conn_name
71+
72+
@property
73+
def closed(self) -> bool:
74+
return self._closed
6575

6676
async def force_refresh(self) -> None:
6777
"""
@@ -121,4 +131,5 @@ async def close(self) -> None:
121131
"""Marks the cache as closed. Provided for a consistent interface with
122132
other cache types.
123133
"""
124-
pass
134+
self._closed = True
135+
return

0 commit comments

Comments
 (0)