Skip to content

Commit 8de1691

Browse files
chore: don't use metadata exchange for asyncpg
1 parent d49316b commit 8de1691

File tree

5 files changed

+35
-9
lines changed

5 files changed

+35
-9
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ async def connect(
9696
self._alloydb_api_endpoint,
9797
self._quota_project,
9898
self._credentials,
99+
driver=driver,
99100
)
100101

101102
# use existing connection info if possible

google/cloud/alloydb/connector/client.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
quota_project: Optional[str],
3939
credentials: Credentials,
4040
client: Optional[aiohttp.ClientSession] = None,
41+
driver: Optional[str] = None,
4142
) -> None:
4243
"""
4344
Establish the client to be used for AlloyDB Admin API requests.
@@ -55,10 +56,12 @@ def __init__(
5556
client (aiohttp.ClientSession): Async client used to make requests to
5657
AlloyDB Admin APIs.
5758
Optional, defaults to None and creates new client.
59+
driver (str): Database driver to be used by the client.
5860
"""
61+
user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT
5962
headers = {
60-
"x-goog-api-client": USER_AGENT,
61-
"User-Agent": USER_AGENT,
63+
"x-goog-api-client": user_agent,
64+
"User-Agent": user_agent,
6265
"Content-Type": "application/json",
6366
}
6467
if quota_project:
@@ -67,7 +70,7 @@ def __init__(
6770
self._client = client if client else aiohttp.ClientSession(headers=headers)
6871
self._credentials = credentials
6972
self._alloydb_api_endpoint = alloydb_api_endpoint
70-
self._user_agent = USER_AGENT
73+
self._user_agent = user_agent
7174

7275
async def _get_metadata(
7376
self,
@@ -147,10 +150,13 @@ async def _get_client_certificate(
147150

148151
url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"
149152

153+
# asyncpg does not currently support using metadata exchange
154+
# only use metadata exchangefor pg8000 driver
155+
use_metadata = self._user_agent.endswith("pg8000")
150156
data = {
151157
"publicKey": pub_key,
152158
"certDuration": "3600s",
153-
"useMetadataExchange": True,
159+
"useMetadataExchange": use_metadata,
154160
}
155161

156162
resp = await self._client.post(

google/cloud/alloydb/connector/connector.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
131131
if self._client is None:
132132
# lazy init client as it has to be initialized in async context
133133
self._client = AlloyDBClient(
134-
self._alloydb_api_endpoint, self._quota_project, self._credentials
134+
self._alloydb_api_endpoint,
135+
self._quota_project,
136+
self._credentials,
137+
driver=driver,
135138
)
136139
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
137140
# use existing connection info if possible

google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
115
from typing import ClassVar as _ClassVar
216
from typing import Optional as _Optional
317
from typing import Union as _Union
@@ -14,7 +28,7 @@ class MetadataExchangeRequest(_message.Message):
1428
__slots__ = ["auth_type", "oauth2_token", "user_agent"]
1529

1630
class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
17-
__slots__ = []
31+
__slots__ = [] # type: ignore
1832
AUTH_TYPE_FIELD_NUMBER: _ClassVar[int]
1933
AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType
2034
AUTO_IAM: MetadataExchangeRequest.AuthType
@@ -35,7 +49,7 @@ class MetadataExchangeResponse(_message.Message):
3549
__slots__ = ["error", "response_code"]
3650

3751
class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
38-
__slots__ = []
52+
__slots__ = [] # type: ignore
3953
ERROR: MetadataExchangeResponse.ResponseCode
4054
ERROR_FIELD_NUMBER: _ClassVar[int]
4155
OK: MetadataExchangeResponse.ResponseCode

tests/unit/mocks.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,12 @@ def get_pem_certs(self) -> Tuple[str, str, str]:
155155
class FakeAlloyDBClient:
156156
"""Fake class for testing AlloyDBClient"""
157157

158-
def __init__(self, instance: Optional[FakeInstance] = None) -> None:
158+
def __init__(
159+
self, instance: Optional[FakeInstance] = None, driver: str = "pg8000"
160+
) -> None:
159161
self.instance = FakeInstance() if instance is None else instance
160162
self.closed = False
161-
self._user_agent = "test-user-agent"
163+
self._user_agent = f"test-user-agent+{driver}"
162164

163165
async def _get_metadata(self, *args: Any, **kwargs: Any) -> str:
164166
return self.instance.ip_address

0 commit comments

Comments
 (0)