Skip to content

Commit 60c145e

Browse files
chore: add tests
1 parent 50ad0bc commit 60c145e

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

google/cloud/alloydb/connector/client.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def __init__(
7070
self._client = client if client else aiohttp.ClientSession(headers=headers)
7171
self._credentials = credentials
7272
self._alloydb_api_endpoint = alloydb_api_endpoint
73-
self._user_agent = user_agent
73+
# asyncpg does not currently support using metadata exchange
74+
# only use metadata exchange for pg8000 driver
75+
self._use_metadata = True if driver == "pg8000" else False
7476

7577
async def _get_metadata(
7678
self,
@@ -150,13 +152,10 @@ async def _get_client_certificate(
150152

151153
url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"
152154

153-
# asyncpg does not currently support using metadata exchange
154-
# only use metadata exchange for pg8000 driver
155-
use_metadata = self._user_agent.endswith("pg8000")
156155
data = {
157156
"publicKey": pub_key,
158157
"certDuration": "3600s",
159-
"useMetadataExchange": use_metadata,
158+
"useMetadataExchange": self._use_metadata,
160159
}
161160

162161
resp = await self._client.post(

tests/unit/test_client.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16-
from typing import Any
16+
from typing import Any, Optional
1717

1818
from aiohttp import web
1919
from mocks import FakeCredentials
@@ -102,3 +102,52 @@ async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None:
102102
assert client._client.headers["x-goog-user-project"] == "my-quota-project"
103103
# close client
104104
await client.close()
105+
106+
107+
@pytest.mark.parametrize(
108+
"driver",
109+
[None, "pg8000", "asyncpg"],
110+
)
111+
@pytest.mark.asyncio
112+
async def test_AlloyDBClient_user_agent(
113+
driver: Optional[str], credentials: FakeCredentials
114+
) -> None:
115+
"""
116+
Test to check whether the __init__ method of AlloyDBClient
117+
properly sets user agent when passed a database driver.
118+
"""
119+
client = AlloyDBClient(
120+
"www.test-endpoint.com", "my-quota-project", credentials, driver=driver
121+
)
122+
if driver is None:
123+
assert (
124+
client._client.headers["User-Agent"]
125+
== f"alloydb-python-connector/{version}"
126+
)
127+
else:
128+
assert (
129+
client._client.headers["User-Agent"]
130+
== f"alloydb-python-connector/{version}+{driver}"
131+
)
132+
# close client
133+
await client.close()
134+
135+
136+
@pytest.mark.parametrize(
137+
"driver, expected",
138+
[(None, False), ("pg8000", True), ("asyncpg", False)],
139+
)
140+
@pytest.mark.asyncio
141+
async def test_AlloyDBClient_use_metadata(
142+
driver: Optional[str], expected: bool, credentials: FakeCredentials
143+
) -> None:
144+
"""
145+
Test to check whether the __init__ method of AlloyDBClient
146+
properly sets use_metadata.
147+
"""
148+
client = AlloyDBClient(
149+
"www.test-endpoint.com", "my-quota-project", credentials, driver=driver
150+
)
151+
assert client._use_metadata == expected
152+
# close client
153+
await client.close()

0 commit comments

Comments
 (0)