Skip to content

Commit

Permalink
chore: don't use metadata exchange for asyncpg
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Jan 12, 2024
1 parent d49316b commit 8de1691
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
1 change: 1 addition & 0 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def connect(
self._alloydb_api_endpoint,
self._quota_project,
self._credentials,
driver=driver,
)

# use existing connection info if possible
Expand Down
14 changes: 10 additions & 4 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
quota_project: Optional[str],
credentials: Credentials,
client: Optional[aiohttp.ClientSession] = None,
driver: Optional[str] = None,
) -> None:
"""
Establish the client to be used for AlloyDB Admin API requests.
Expand All @@ -55,10 +56,12 @@ def __init__(
client (aiohttp.ClientSession): Async client used to make requests to
AlloyDB Admin APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
"""
user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT
headers = {
"x-goog-api-client": USER_AGENT,
"User-Agent": USER_AGENT,
"x-goog-api-client": user_agent,
"User-Agent": user_agent,
"Content-Type": "application/json",
}
if quota_project:
Expand All @@ -67,7 +70,7 @@ def __init__(
self._client = client if client else aiohttp.ClientSession(headers=headers)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
self._user_agent = USER_AGENT
self._user_agent = user_agent

async def _get_metadata(
self,
Expand Down Expand Up @@ -147,10 +150,13 @@ async def _get_client_certificate(

url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"

# asyncpg does not currently support using metadata exchange
# only use metadata exchangefor pg8000 driver
use_metadata = self._user_agent.endswith("pg8000")
data = {
"publicKey": pub_key,
"certDuration": "3600s",
"useMetadataExchange": True,
"useMetadataExchange": use_metadata,
}

resp = await self._client.post(
Expand Down
5 changes: 4 additions & 1 deletion google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if self._client is None:
# lazy init client as it has to be initialized in async context
self._client = AlloyDBClient(
self._alloydb_api_endpoint, self._quota_project, self._credentials
self._alloydb_api_endpoint,
self._quota_project,
self._credentials,
driver=driver,
)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
# use existing connection info if possible
Expand Down
18 changes: 16 additions & 2 deletions google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import ClassVar as _ClassVar
from typing import Optional as _Optional
from typing import Union as _Union
Expand All @@ -14,7 +28,7 @@ class MetadataExchangeRequest(_message.Message):
__slots__ = ["auth_type", "oauth2_token", "user_agent"]

class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
__slots__ = [] # type: ignore
AUTH_TYPE_FIELD_NUMBER: _ClassVar[int]
AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType
AUTO_IAM: MetadataExchangeRequest.AuthType
Expand All @@ -35,7 +49,7 @@ class MetadataExchangeResponse(_message.Message):
__slots__ = ["error", "response_code"]

class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
__slots__ = [] # type: ignore
ERROR: MetadataExchangeResponse.ResponseCode
ERROR_FIELD_NUMBER: _ClassVar[int]
OK: MetadataExchangeResponse.ResponseCode
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,12 @@ def get_pem_certs(self) -> Tuple[str, str, str]:
class FakeAlloyDBClient:
"""Fake class for testing AlloyDBClient"""

def __init__(self, instance: Optional[FakeInstance] = None) -> None:
def __init__(
self, instance: Optional[FakeInstance] = None, driver: str = "pg8000"
) -> None:
self.instance = FakeInstance() if instance is None else instance
self.closed = False
self._user_agent = "test-user-agent"
self._user_agent = f"test-user-agent+{driver}"

async def _get_metadata(self, *args: Any, **kwargs: Any) -> str:
return self.instance.ip_address
Expand Down

0 comments on commit 8de1691

Please sign in to comment.