Skip to content

Commit 15934bd

Browse files
feat: add domain name validation (#1246)
1 parent eeac5f7 commit 15934bd

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

google/cloud/sql/connector/connection_name.py

+10
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
# Additionally, we have to support legacy "domain-scoped" projects
2020
# (e.g. "google.com:PROJECT")
2121
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))
22+
# The domain name pattern in accordance with RFC 1035, RFC 1123 and RFC 2181.
23+
DOMAIN_NAME_REGEX = re.compile(
24+
r"^(?:[_a-z0-9](?:[_a-z0-9-]{0,61}[a-z0-9])?\.)+(?:[a-z](?:[a-z0-9-]{0,61}[a-z0-9])?)?$"
25+
)
2226

2327

2428
@dataclass
@@ -39,6 +43,12 @@ def __str__(self) -> str:
3943
return f"{self.project}:{self.region}:{self.instance_name}"
4044

4145

46+
def _is_valid_domain(domain_name: str) -> bool:
47+
if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None:
48+
return False
49+
return True
50+
51+
4252
def _parse_connection_name(connection_name: str) -> ConnectionName:
4353
return _parse_connection_name_with_domain_name(connection_name, "")
4454

google/cloud/sql/connector/resolver.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.cloud.sql.connector.connection_name import (
1818
_parse_connection_name_with_domain_name,
1919
)
20+
from google.cloud.sql.connector.connection_name import _is_valid_domain
2021
from google.cloud.sql.connector.connection_name import _parse_connection_name
2122
from google.cloud.sql.connector.connection_name import ConnectionName
2223
from google.cloud.sql.connector.exceptions import DnsResolutionError
@@ -40,8 +41,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore
4041
conn_name = _parse_connection_name(dns)
4142
except ValueError:
4243
# The connection name was not project:region:instance format.
43-
# Attempt to query a TXT record to get connection name.
44-
conn_name = await self.query_dns(dns)
44+
# Check if connection name is a valid DNS domain name
45+
if _is_valid_domain(dns):
46+
# Attempt to query a TXT record to get connection name.
47+
conn_name = await self.query_dns(dns)
48+
else:
49+
raise ValueError(
50+
"Arg `instance_connection_string` must have "
51+
"format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
52+
f"name, got {dns}."
53+
)
4554
return conn_name
4655

4756
async def query_dns(self, dns: str) -> ConnectionName:

tests/unit/test_connection_name.py

+38
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.cloud.sql.connector.connection_name import (
1818
_parse_connection_name_with_domain_name,
1919
)
20+
from google.cloud.sql.connector.connection_name import _is_valid_domain
2021
from google.cloud.sql.connector.connection_name import _parse_connection_name
2122
from google.cloud.sql.connector.connection_name import ConnectionName
2223

@@ -96,3 +97,40 @@ def test_parse_connection_name_with_domain_name(
9697
assert expected == _parse_connection_name_with_domain_name(
9798
connection_name, domain_name
9899
)
100+
101+
102+
@pytest.mark.parametrize(
103+
"domain_name, expected",
104+
[
105+
(
106+
"prod-db.mycompany.example.com",
107+
True,
108+
),
109+
(
110+
"example.com.", # trailing dot
111+
True,
112+
),
113+
(
114+
"-example.com.", # leading hyphen
115+
False,
116+
),
117+
(
118+
"example", # missing TLD
119+
False,
120+
),
121+
(
122+
"127.0.0.1", # IPv4 address
123+
False,
124+
),
125+
(
126+
"0:0:0:0:0:0:0:1", # IPv6 address
127+
False,
128+
),
129+
],
130+
)
131+
def test_is_valid_domain(domain_name: str, expected: bool) -> None:
132+
"""
133+
Test that _is_valid_domain works correctly for
134+
parsing domain names.
135+
"""
136+
assert expected == _is_valid_domain(domain_name)

0 commit comments

Comments
 (0)