Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support ip_type as str #267

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

[flake8]
ignore = E203, E231, E266, E501, W503, ANN101, ANN401
ignore = E203, E231, E266, E501, W503, ANN101, ANN102, ANN401
exclude =
# Exclude generated code.
**/proto/**
Expand Down
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,12 +381,11 @@ to your instance's private IP. To change this, such as connecting to AlloyDB
over a public IP address, set the `ip_type` keyword argument when initializing
a `Connector()` or when calling `connector.connect()`.

Possible values for `ip_type` are `IPTypes.PRIVATE` (default value), and
`IPTypes.PUBLIC`.
Possible values for `ip_type` are `"PRIVATE"` (default value), and `"PUBLIC"`.
Example:

```python
from google.cloud.alloydb.connector import Connector, IPTypes
from google.cloud.alloydb.connector import Connector

import sqlalchemy

Expand All @@ -401,7 +400,7 @@ def getconn():
user="my-user",
password="my-password",
db="my-db-name",
ip_type=IPTypes.PUBLIC, # use public IP
ip_type="PUBLIC", # use public IP
)

# create connection pool
Expand Down
14 changes: 10 additions & 4 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class AsyncConnector:
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE for private IP connections.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
"""

def __init__(
Expand All @@ -57,14 +57,17 @@ def __init__(
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: IPTypes = IPTypes.PRIVATE,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
) -> None:
self._instances: Dict[str, Instance] = {}
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type)
self._ip_type = ip_type
self._user_agent = user_agent
# initialize credentials
Expand Down Expand Up @@ -144,7 +147,10 @@ async def connect(
kwargs.pop("port", None)

# get connection info for AlloyDB instance
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
ip_type: str | IPTypes = kwargs.pop("ip_type", self._ip_type)
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type)
ip_address, context = await instance.connection_info(ip_type)

# callable to be used for auto IAM authn
Expand Down
14 changes: 10 additions & 4 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class Connector:
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
ip_type (IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE for private IP connections.
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
"""

def __init__(
Expand All @@ -67,7 +67,7 @@ def __init__(
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
ip_type: IPTypes = IPTypes.PRIVATE,
ip_type: str | IPTypes = IPTypes.PRIVATE,
user_agent: Optional[str] = None,
) -> None:
# create event loop and start it in background thread
Expand All @@ -79,6 +79,9 @@ def __init__(
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type)
self._ip_type = ip_type
self._user_agent = user_agent
# initialize credentials
Expand Down Expand Up @@ -171,7 +174,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
kwargs.pop("port", None)

# get connection info for AlloyDB instance
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
ip_type: IPTypes | str = kwargs.pop("ip_type", self._ip_type)
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type)
ip_address, context = await instance.connection_info(ip_type)

# synchronous drivers are blocking and run using executor
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/alloydb/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class IPTypes(Enum):
PUBLIC: str = "PUBLIC"
PRIVATE: str = "PRIVATE"

@classmethod
def _missing_(cls, value: object) -> None:
raise ValueError(
f"Incorrect value for ip_type, got '{value}'. Want one of: "
f"{', '.join([repr(m.value) for m in cls])}."
)


def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]:
# should take form "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"
Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_asyncpg_public_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import sqlalchemy.ext.asyncio

from google.cloud.alloydb.connector import AsyncConnector
from google.cloud.alloydb.connector import IPTypes


async def create_sqlalchemy_engine(
Expand Down Expand Up @@ -70,7 +69,7 @@ async def getconn() -> asyncpg.Connection:
user=user,
password=password,
db=db,
ip_type=IPTypes.PUBLIC,
ip_type="PUBLIC",
)
return conn

Expand Down
3 changes: 1 addition & 2 deletions tests/system/test_pg8000_public_ip.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import sqlalchemy

from google.cloud.alloydb.connector import Connector
from google.cloud.alloydb.connector import IPTypes


def create_sqlalchemy_engine(
Expand Down Expand Up @@ -70,7 +69,7 @@ def getconn() -> pg8000.dbapi.Connection:
user=user,
password=password,
db=db,
ip_type=IPTypes.PUBLIC,
ip_type="PUBLIC",
)
return conn

Expand Down
33 changes: 33 additions & 0 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None:
await connector.close()


async def test_AsyncConnector_init_bad_ip_type(credentials: FakeCredentials) -> None:
"""Test that AsyncConnector errors due to bad ip_type str."""
bad_ip_type = "bad-ip-type"
with pytest.raises(ValueError) as exc_info:
AsyncConnector(ip_type=bad_ip_type, credentials=credentials)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)


@pytest.mark.asyncio
async def test_AsyncConnector_context_manager(
credentials: FakeCredentials,
Expand Down Expand Up @@ -202,3 +213,25 @@ def test_synchronous_init(credentials: FakeCredentials) -> None:
"""
connector = AsyncConnector(credentials)
assert connector._keys is None


async def test_async_connect_bad_ip_type(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""Test that AyncConnector.connect errors due to bad ip_type str."""
async with AsyncConnector(credentials=credentials) as connector:
connector._client = fake_client
bad_ip_type = "bad-ip-type"
with pytest.raises(ValueError) as exc_info:
await connector.connect(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
"asyncpg",
user="test-user",
password="test-password",
db="test-db",
ip_type=bad_ip_type,
)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)
33 changes: 33 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ def test_Connector_init(credentials: FakeCredentials) -> None:
connector.close()


def test_Connector_init_bad_ip_type(credentials: FakeCredentials) -> None:
"""Test that Connector errors due to bad ip_type str."""
bad_ip_type = "bad-ip-type"
with pytest.raises(ValueError) as exc_info:
Connector(ip_type=bad_ip_type, credentials=credentials)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)


def test_Connector_context_manager(credentials: FakeCredentials) -> None:
"""
Test to check whether the __init__ method of Connector
Expand Down Expand Up @@ -84,6 +95,28 @@ def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -
assert connection is True


def test_connect_bad_ip_type(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""Test that Connector.connect errors due to bad ip_type str."""
with Connector(credentials=credentials) as connector:
connector._client = fake_client
bad_ip_type = "bad-ip-type"
with pytest.raises(ValueError) as exc_info:
connector.connect(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
"pg8000",
user="test-user",
password="test-password",
db="test-db",
ip_type=bad_ip_type,
)
assert (
exc_info.value.args[0]
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
)


def test_connect_unsupported_driver(credentials: FakeCredentials) -> None:
"""
Test that connector.connect errors with unsupported database driver.
Expand Down