Skip to content

Commit eeac5f7

Browse files
test: add new int tests for Connector with domain name (#1244)
1 parent fb8c21c commit eeac5f7

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

.github/workflows/tests.yml

+2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ jobs:
8181
POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS
8282
POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME
8383
POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS
84+
POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME
8485
SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME
8586
SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER
8687
SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS
@@ -102,6 +103,7 @@ jobs:
102103
POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}"
103104
POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}"
104105
POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}"
106+
POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME }}"
105107
SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}"
106108
SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}"
107109
SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}"

tests/system/test_asyncpg_connection.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616

1717
import asyncio
1818
import os
19-
from typing import Any
19+
from typing import Any, Union
2020

2121
import asyncpg
2222
import sqlalchemy
2323
import sqlalchemy.ext.asyncio
2424

2525
from google.cloud.sql.connector import Connector
26+
from google.cloud.sql.connector import DefaultResolver
27+
from google.cloud.sql.connector import DnsResolver
2628

2729

2830
async def create_sqlalchemy_engine(
@@ -31,6 +33,7 @@ async def create_sqlalchemy_engine(
3133
password: str,
3234
db: str,
3335
refresh_strategy: str = "background",
36+
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
3437
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]:
3538
"""Creates a connection pool for a Cloud SQL instance and returns the pool
3639
and the connector. Callers are responsible for closing the pool and the
@@ -64,9 +67,16 @@ async def create_sqlalchemy_engine(
6467
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
6568
or "background". For serverless environments use "lazy" to avoid
6669
errors resulting from CPU being throttled.
70+
resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
71+
Resolver class for resolving instance connection name. Use
72+
google.cloud.sql.connector.DnsResolver when resolving DNS domain
73+
names or google.cloud.sql.connector.DefaultResolver for regular
74+
instance connection names ("my-project:my-region:my-instance").
6775
"""
6876
loop = asyncio.get_running_loop()
69-
connector = Connector(loop=loop, refresh_strategy=refresh_strategy)
77+
connector = Connector(
78+
loop=loop, refresh_strategy=refresh_strategy, resolver=resolver
79+
)
7080

7181
async def getconn() -> asyncpg.Connection:
7282
conn: asyncpg.Connection = await connector.connect_async(
@@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None:
183193
await connector.close_async()
184194

185195

196+
async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg() -> None:
197+
"""Basic test to get time from database."""
198+
inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"]
199+
user = os.environ["POSTGRES_USER"]
200+
password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
201+
db = os.environ["POSTGRES_DB"]
202+
203+
pool, connector = await create_sqlalchemy_engine(
204+
inst_conn_name, user, password, db, resolver=DnsResolver
205+
)
206+
207+
async with pool.connect() as conn:
208+
res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone()
209+
assert res[0] == 1
210+
211+
await connector.close_async()
212+
213+
186214
async def test_connection_with_asyncpg() -> None:
187215
"""Basic test to get time from database."""
188216
inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"]

tests/system/test_pg8000_connection.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818
import os
1919

2020
# [START cloud_sql_connector_postgres_pg8000]
21+
from typing import Union
22+
2123
import pg8000
2224
import sqlalchemy
2325

2426
from google.cloud.sql.connector import Connector
27+
from google.cloud.sql.connector import DefaultResolver
28+
from google.cloud.sql.connector import DnsResolver
2529

2630

2731
def create_sqlalchemy_engine(
@@ -30,6 +34,7 @@ def create_sqlalchemy_engine(
3034
password: str,
3135
db: str,
3236
refresh_strategy: str = "background",
37+
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
3338
) -> tuple[sqlalchemy.engine.Engine, Connector]:
3439
"""Creates a connection pool for a Cloud SQL instance and returns the pool
3540
and the connector. Callers are responsible for closing the pool and the
@@ -64,8 +69,13 @@ def create_sqlalchemy_engine(
6469
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
6570
or "background". For serverless environments use "lazy" to avoid
6671
errors resulting from CPU being throttled.
72+
resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
73+
Resolver class for resolving instance connection name. Use
74+
google.cloud.sql.connector.DnsResolver when resolving DNS domain
75+
names or google.cloud.sql.connector.DefaultResolver for regular
76+
instance connection names ("my-project:my-region:my-instance").
6777
"""
68-
connector = Connector(refresh_strategy=refresh_strategy)
78+
connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver)
6979

7080
def getconn() -> pg8000.dbapi.Connection:
7181
conn: pg8000.dbapi.Connection = connector.connect(
@@ -153,3 +163,21 @@ def test_customer_managed_CAS_pg8000_connection() -> None:
153163
curr_time = time[0]
154164
assert type(curr_time) is datetime
155165
connector.close()
166+
167+
168+
def test_custom_SAN_with_dns_pg8000_connection() -> None:
169+
"""Basic test to get time from database."""
170+
inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"]
171+
user = os.environ["POSTGRES_USER"]
172+
password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
173+
db = os.environ["POSTGRES_DB"]
174+
175+
engine, connector = create_sqlalchemy_engine(
176+
inst_conn_name, user, password, db, resolver=DnsResolver
177+
)
178+
with engine.connect() as conn:
179+
time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
180+
conn.commit()
181+
curr_time = time[0]
182+
assert type(curr_time) is datetime
183+
connector.close()

0 commit comments

Comments
 (0)