Skip to content

Commit b7b1d99

Browse files
feat: support ip_type as str (#267)
1 parent 9782f6e commit b7b1d99

9 files changed

+185
-17
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616

1717
[flake8]
18-
ignore = E203, E231, E266, E501, W503, ANN101, ANN401
18+
ignore = E203, E231, E266, E501, W503, ANN101, ANN102, ANN401
1919
exclude =
2020
# Exclude generated code.
2121
**/proto/**

README.md

+3-4
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,11 @@ to your instance's private IP. To change this, such as connecting to AlloyDB
381381
over a public IP address, set the `ip_type` keyword argument when initializing
382382
a `Connector()` or when calling `connector.connect()`.
383383

384-
Possible values for `ip_type` are `IPTypes.PRIVATE` (default value), and
385-
`IPTypes.PUBLIC`.
384+
Possible values for `ip_type` are `"PRIVATE"` (default value), and `"PUBLIC"`.
386385
Example:
387386

388387
```python
389-
from google.cloud.alloydb.connector import Connector, IPTypes
388+
from google.cloud.alloydb.connector import Connector
390389

391390
import sqlalchemy
392391

@@ -401,7 +400,7 @@ def getconn():
401400
user="my-user",
402401
password="my-password",
403402
db="my-db-name",
404-
ip_type=IPTypes.PUBLIC, # use public IP
403+
ip_type="PUBLIC", # use public IP
405404
)
406405

407406
# create connection pool

google/cloud/alloydb/connector/async_connector.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class AsyncConnector:
4747
alloydb_api_endpoint (str): Base URL to use when calling
4848
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
4949
enable_iam_auth (bool): Enables automatic IAM database authentication.
50-
ip_type (IPTypes): Default IP type for all AlloyDB connections.
51-
Defaults to IPTypes.PRIVATE for private IP connections.
50+
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
51+
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
5252
"""
5353

5454
def __init__(
@@ -57,14 +57,17 @@ def __init__(
5757
quota_project: Optional[str] = None,
5858
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
5959
enable_iam_auth: bool = False,
60-
ip_type: IPTypes = IPTypes.PRIVATE,
60+
ip_type: str | IPTypes = IPTypes.PRIVATE,
6161
user_agent: Optional[str] = None,
6262
) -> None:
6363
self._instances: Dict[str, Instance] = {}
6464
# initialize default params
6565
self._quota_project = quota_project
6666
self._alloydb_api_endpoint = alloydb_api_endpoint
6767
self._enable_iam_auth = enable_iam_auth
68+
# if ip_type is str, convert to IPTypes enum
69+
if isinstance(ip_type, str):
70+
ip_type = IPTypes(ip_type.upper())
6871
self._ip_type = ip_type
6972
self._user_agent = user_agent
7073
# initialize credentials
@@ -144,7 +147,10 @@ async def connect(
144147
kwargs.pop("port", None)
145148

146149
# get connection info for AlloyDB instance
147-
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
150+
ip_type: str | IPTypes = kwargs.pop("ip_type", self._ip_type)
151+
# if ip_type is str, convert to IPTypes enum
152+
if isinstance(ip_type, str):
153+
ip_type = IPTypes(ip_type.upper())
148154
ip_address, context = await instance.connection_info(ip_type)
149155

150156
# callable to be used for auto IAM authn

google/cloud/alloydb/connector/connector.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ class Connector:
5757
alloydb_api_endpoint (str): Base URL to use when calling
5858
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
5959
enable_iam_auth (bool): Enables automatic IAM database authentication.
60-
ip_type (IPTypes): Default IP type for all AlloyDB connections.
61-
Defaults to IPTypes.PRIVATE for private IP connections.
60+
ip_type (str | IPTypes): Default IP type for all AlloyDB connections.
61+
Defaults to IPTypes.PRIVATE ("PRIVATE") for private IP connections.
6262
"""
6363

6464
def __init__(
@@ -67,7 +67,7 @@ def __init__(
6767
quota_project: Optional[str] = None,
6868
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
6969
enable_iam_auth: bool = False,
70-
ip_type: IPTypes = IPTypes.PRIVATE,
70+
ip_type: str | IPTypes = IPTypes.PRIVATE,
7171
user_agent: Optional[str] = None,
7272
) -> None:
7373
# create event loop and start it in background thread
@@ -79,6 +79,9 @@ def __init__(
7979
self._quota_project = quota_project
8080
self._alloydb_api_endpoint = alloydb_api_endpoint
8181
self._enable_iam_auth = enable_iam_auth
82+
# if ip_type is str, convert to IPTypes enum
83+
if isinstance(ip_type, str):
84+
ip_type = IPTypes(ip_type.upper())
8285
self._ip_type = ip_type
8386
self._user_agent = user_agent
8487
# initialize credentials
@@ -171,7 +174,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
171174
kwargs.pop("port", None)
172175

173176
# get connection info for AlloyDB instance
174-
ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type)
177+
ip_type: IPTypes | str = kwargs.pop("ip_type", self._ip_type)
178+
# if ip_type is str, convert to IPTypes enum
179+
if isinstance(ip_type, str):
180+
ip_type = IPTypes(ip_type.upper())
175181
ip_address, context = await instance.connection_info(ip_type)
176182

177183
# synchronous drivers are blocking and run using executor

google/cloud/alloydb/connector/instance.py

+7
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ class IPTypes(Enum):
4949
PUBLIC: str = "PUBLIC"
5050
PRIVATE: str = "PRIVATE"
5151

52+
@classmethod
53+
def _missing_(cls, value: object) -> None:
54+
raise ValueError(
55+
f"Incorrect value for ip_type, got '{value}'. Want one of: "
56+
f"{', '.join([repr(m.value) for m in cls])}."
57+
)
58+
5259

5360
def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]:
5461
# should take form "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>"

tests/system/test_asyncpg_public_ip.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import sqlalchemy.ext.asyncio
2323

2424
from google.cloud.alloydb.connector import AsyncConnector
25-
from google.cloud.alloydb.connector import IPTypes
2625

2726

2827
async def create_sqlalchemy_engine(
@@ -70,7 +69,7 @@ async def getconn() -> asyncpg.Connection:
7069
user=user,
7170
password=password,
7271
db=db,
73-
ip_type=IPTypes.PUBLIC,
72+
ip_type="PUBLIC",
7473
)
7574
return conn
7675

tests/system/test_pg8000_public_ip.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import sqlalchemy
2222

2323
from google.cloud.alloydb.connector import Connector
24-
from google.cloud.alloydb.connector import IPTypes
2524

2625

2726
def create_sqlalchemy_engine(
@@ -70,7 +69,7 @@ def getconn() -> pg8000.dbapi.Connection:
7069
user=user,
7170
password=password,
7271
db=db,
73-
ip_type=IPTypes.PUBLIC,
72+
ip_type="PUBLIC",
7473
)
7574
return conn
7675

tests/unit/test_async_connector.py

+76
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
from typing import Union
1617

1718
from mock import patch
1819
from mocks import FakeAlloyDBClient
@@ -21,6 +22,7 @@
2122
import pytest
2223

2324
from google.cloud.alloydb.connector import AsyncConnector
25+
from google.cloud.alloydb.connector import IPTypes
2426

2527
ALLOYDB_API_ENDPOINT = "https://alloydb.googleapis.com"
2628

@@ -40,6 +42,58 @@ async def test_AsyncConnector_init(credentials: FakeCredentials) -> None:
4042
await connector.close()
4143

4244

45+
@pytest.mark.parametrize(
46+
"ip_type, expected",
47+
[
48+
(
49+
"private",
50+
IPTypes.PRIVATE,
51+
),
52+
(
53+
"PRIVATE",
54+
IPTypes.PRIVATE,
55+
),
56+
(
57+
IPTypes.PRIVATE,
58+
IPTypes.PRIVATE,
59+
),
60+
(
61+
"public",
62+
IPTypes.PUBLIC,
63+
),
64+
(
65+
"PUBLIC",
66+
IPTypes.PUBLIC,
67+
),
68+
(
69+
IPTypes.PUBLIC,
70+
IPTypes.PUBLIC,
71+
),
72+
],
73+
)
74+
async def test_AsyncConnector_init_ip_type(
75+
ip_type: Union[str, IPTypes], expected: IPTypes, credentials: FakeCredentials
76+
) -> None:
77+
"""
78+
Test to check whether the __init__ method of AsyncConnector
79+
properly sets ip_type.
80+
"""
81+
connector = AsyncConnector(credentials=credentials, ip_type=ip_type)
82+
assert connector._ip_type == expected
83+
connector.close()
84+
85+
86+
async def test_AsyncConnector_init_bad_ip_type(credentials: FakeCredentials) -> None:
87+
"""Test that AsyncConnector errors due to bad ip_type str."""
88+
bad_ip_type = "BAD-IP-TYPE"
89+
with pytest.raises(ValueError) as exc_info:
90+
AsyncConnector(ip_type=bad_ip_type, credentials=credentials)
91+
assert (
92+
exc_info.value.args[0]
93+
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
94+
)
95+
96+
4397
@pytest.mark.asyncio
4498
async def test_AsyncConnector_context_manager(
4599
credentials: FakeCredentials,
@@ -202,3 +256,25 @@ def test_synchronous_init(credentials: FakeCredentials) -> None:
202256
"""
203257
connector = AsyncConnector(credentials)
204258
assert connector._keys is None
259+
260+
261+
async def test_async_connect_bad_ip_type(
262+
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
263+
) -> None:
264+
"""Test that AyncConnector.connect errors due to bad ip_type str."""
265+
async with AsyncConnector(credentials=credentials) as connector:
266+
connector._client = fake_client
267+
bad_ip_type = "BAD-IP-TYPE"
268+
with pytest.raises(ValueError) as exc_info:
269+
await connector.connect(
270+
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
271+
"asyncpg",
272+
user="test-user",
273+
password="test-password",
274+
db="test-db",
275+
ip_type=bad_ip_type,
276+
)
277+
assert (
278+
exc_info.value.args[0]
279+
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
280+
)

tests/unit/test_connector.py

+76
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
import asyncio
1616
from threading import Thread
17+
from typing import Union
1718

1819
from mock import patch
1920
from mocks import FakeAlloyDBClient
2021
from mocks import FakeCredentials
2122
import pytest
2223

2324
from google.cloud.alloydb.connector import Connector
25+
from google.cloud.alloydb.connector import IPTypes
2426

2527

2628
def test_Connector_init(credentials: FakeCredentials) -> None:
@@ -36,6 +38,58 @@ def test_Connector_init(credentials: FakeCredentials) -> None:
3638
connector.close()
3739

3840

41+
def test_Connector_init_bad_ip_type(credentials: FakeCredentials) -> None:
42+
"""Test that Connector errors due to bad ip_type str."""
43+
bad_ip_type = "BAD-IP-TYPE"
44+
with pytest.raises(ValueError) as exc_info:
45+
Connector(ip_type=bad_ip_type, credentials=credentials)
46+
assert (
47+
exc_info.value.args[0]
48+
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
49+
)
50+
51+
52+
@pytest.mark.parametrize(
53+
"ip_type, expected",
54+
[
55+
(
56+
"private",
57+
IPTypes.PRIVATE,
58+
),
59+
(
60+
"PRIVATE",
61+
IPTypes.PRIVATE,
62+
),
63+
(
64+
IPTypes.PRIVATE,
65+
IPTypes.PRIVATE,
66+
),
67+
(
68+
"public",
69+
IPTypes.PUBLIC,
70+
),
71+
(
72+
"PUBLIC",
73+
IPTypes.PUBLIC,
74+
),
75+
(
76+
IPTypes.PUBLIC,
77+
IPTypes.PUBLIC,
78+
),
79+
],
80+
)
81+
def test_Connector_init_ip_type(
82+
ip_type: Union[str, IPTypes], expected: IPTypes, credentials: FakeCredentials
83+
) -> None:
84+
"""
85+
Test to check whether the __init__ method of Connector
86+
properly sets ip_type.
87+
"""
88+
connector = Connector(credentials=credentials, ip_type=ip_type)
89+
assert connector._ip_type == expected
90+
connector.close()
91+
92+
3993
def test_Connector_context_manager(credentials: FakeCredentials) -> None:
4094
"""
4195
Test to check whether the __init__ method of Connector
@@ -84,6 +138,28 @@ def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -
84138
assert connection is True
85139

86140

141+
def test_connect_bad_ip_type(
142+
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
143+
) -> None:
144+
"""Test that Connector.connect errors due to bad ip_type str."""
145+
with Connector(credentials=credentials) as connector:
146+
connector._client = fake_client
147+
bad_ip_type = "BAD-IP-TYPE"
148+
with pytest.raises(ValueError) as exc_info:
149+
connector.connect(
150+
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
151+
"pg8000",
152+
user="test-user",
153+
password="test-password",
154+
db="test-db",
155+
ip_type=bad_ip_type,
156+
)
157+
assert (
158+
exc_info.value.args[0]
159+
== f"Incorrect value for ip_type, got '{bad_ip_type}'. Want one of: 'PUBLIC', 'PRIVATE'."
160+
)
161+
162+
87163
def test_connect_unsupported_driver(credentials: FakeCredentials) -> None:
88164
"""
89165
Test that connector.connect errors with unsupported database driver.

0 commit comments

Comments
 (0)