16
16
17
17
import asyncio
18
18
import os
19
- from typing import Any
19
+ from typing import Any , Union
20
20
21
21
import asyncpg
22
22
import sqlalchemy
23
23
import sqlalchemy .ext .asyncio
24
24
25
25
from google .cloud .sql .connector import Connector
26
+ from google .cloud .sql .connector import DefaultResolver
27
+ from google .cloud .sql .connector import DnsResolver
26
28
27
29
28
30
async def create_sqlalchemy_engine (
@@ -31,6 +33,7 @@ async def create_sqlalchemy_engine(
31
33
password : str ,
32
34
db : str ,
33
35
refresh_strategy : str = "background" ,
36
+ resolver : Union [type [DefaultResolver ], type [DnsResolver ]] = DefaultResolver ,
34
37
) -> tuple [sqlalchemy .ext .asyncio .engine .AsyncEngine , Connector ]:
35
38
"""Creates a connection pool for a Cloud SQL instance and returns the pool
36
39
and the connector. Callers are responsible for closing the pool and the
@@ -64,9 +67,16 @@ async def create_sqlalchemy_engine(
64
67
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
65
68
or "background". For serverless environments use "lazy" to avoid
66
69
errors resulting from CPU being throttled.
70
+ resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
71
+ Resolver class for resolving instance connection name. Use
72
+ google.cloud.sql.connector.DnsResolver when resolving DNS domain
73
+ names or google.cloud.sql.connector.DefaultResolver for regular
74
+ instance connection names ("my-project:my-region:my-instance").
67
75
"""
68
76
loop = asyncio .get_running_loop ()
69
- connector = Connector (loop = loop , refresh_strategy = refresh_strategy )
77
+ connector = Connector (
78
+ loop = loop , refresh_strategy = refresh_strategy , resolver = resolver
79
+ )
70
80
71
81
async def getconn () -> asyncpg .Connection :
72
82
conn : asyncpg .Connection = await connector .connect_async (
@@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None:
183
193
await connector .close_async ()
184
194
185
195
196
+ async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg () -> None :
197
+ """Basic test to get time from database."""
198
+ inst_conn_name = os .environ ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME" ]
199
+ user = os .environ ["POSTGRES_USER" ]
200
+ password = os .environ ["POSTGRES_CUSTOMER_CAS_PASS" ]
201
+ db = os .environ ["POSTGRES_DB" ]
202
+
203
+ pool , connector = await create_sqlalchemy_engine (
204
+ inst_conn_name , user , password , db , resolver = DnsResolver
205
+ )
206
+
207
+ async with pool .connect () as conn :
208
+ res = (await conn .execute (sqlalchemy .text ("SELECT 1" ))).fetchone ()
209
+ assert res [0 ] == 1
210
+
211
+ await connector .close_async ()
212
+
213
+
186
214
async def test_connection_with_asyncpg () -> None :
187
215
"""Basic test to get time from database."""
188
216
inst_conn_name = os .environ ["POSTGRES_CONNECTION_NAME" ]
0 commit comments