Skip to content

Commit bebcef0

Browse files
committed
fix: consolidate region validation into centralized entry points
Move validate_region() into Session._initialize(), BaseInteractiveApp.__init__(), and DetailProfilerApp.__init__() so that region is validated automatically at object creation time, reducing the chance of future developers forgetting to add per-site checks. Remove 7 redundant per-site validate_region() calls where region already comes from a validated Session: - session_helper.py sts_regional_endpoint() - spark/processing.py - jumpstart/utils.py - telemetry_logging.py - serve/telemetry_logger.py - tensorboard.py - detail_profiler_app.py method-level call Retain 8 per-site calls where region bypasses Session (direct function params or ARN parsing): common_utils.py (2), image_retriever.py (4), image_retriever_utils.py (1), image_uris.py (2), metrics_visualizer.py (2).
1 parent 710897a commit bebcef0

11 files changed

Lines changed: 87 additions & 35 deletions

File tree

sagemaker-core/src/sagemaker/core/helper/session_helper.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ def _initialize(
228228
"Must setup local AWS configuration with a region supported by SageMaker."
229229
)
230230

231+
from sagemaker.core.region_validation import validate_region
232+
233+
validate_region(self._region_name)
234+
231235
# Make use of user_agent_extra field of the botocore_config object
232236
# to append SageMaker Python SDK specific user_agent suffix
233237
# to the current User-Agent header value from boto3
@@ -2121,9 +2125,6 @@ def sts_regional_endpoint(region):
21212125
Returns:
21222126
str: AWS STS regional endpoint
21232127
"""
2124-
from sagemaker.core.region_validation import validate_region
2125-
2126-
validate_region(region)
21272128
endpoint_data = botocore_resolver().construct_endpoint("sts", region)
21282129
if region == "il-central-1" and not endpoint_data:
21292130
endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}

sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ def __init__(
4343
one is created using the default AWS configuration chain.
4444
Default: ``None``
4545
"""
46+
from sagemaker.core.region_validation import validate_region
47+
4648
if isinstance(region, str):
4749
self.region = region
4850
else:
@@ -55,6 +57,7 @@ def __init__(
5557
" configuration."
5658
)
5759

60+
validate_region(self.region)
5861
self._sagemaker_client = boto3.client("sagemaker", region_name=self.region)
5962
# Used to store domain and user profile info retrieved from Studio environment.
6063
self._domain_id = None

sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def __init__(self, region: Optional[str] = None):
3838
region (str): The name of the region e.g. us-east-1. If not specified,
3939
one is created using the default AWS configuration chain.
4040
"""
41+
from sagemaker.core.region_validation import validate_region
42+
4143
if region:
4244
self.region = region
4345
else:
@@ -49,6 +51,8 @@ def __init__(self, region: Optional[str] = None):
4951
"as an input argument or setup the local AWS config."
5052
)
5153

54+
validate_region(self.region)
55+
5256
self._domain_id = None
5357
self._user_profile_name = None
5458
self._valid_domain_and_user = False
@@ -79,10 +83,6 @@ def get_app_url(self, training_job_name: Optional[str] = None):
7983
Returns:
8084
str: An unsigned URL for DetailProfiler hosted on SageMaker.
8185
"""
82-
from sagemaker.core.region_validation import validate_region
83-
84-
validate_region(self.region)
85-
8686
if self._valid_domain_and_user:
8787
url = f"https://{self._domain_id}.studio.{self.region}.sagemaker.aws/profiler/default"
8888
if training_job_name is not None:

sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,9 @@ def get_app_url(
8484
Returns:
8585
str: A URL for TensorBoard hosted on SageMaker.
8686
"""
87-
from sagemaker.core.region_validation import validate_region
88-
8987
if training_job_name is not None:
9088
self._validate_job_name(training_job_name)
9189

92-
validate_region(self.region)
93-
9490
if (
9591
self._in_studio_env
9692
and self._validate_domain_id(self._domain_id)

sagemaker-core/src/sagemaker/core/jumpstart/utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,11 @@ def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Sessi
8888
if sagemaker_session is None:
8989
sagemaker_session = Session()
9090

91-
from sagemaker.core.region_validation import validate_region
92-
9391
path_parts = document.HostingEulaUri.replace("s3://", "").split("/")
9492

9593
bucket = path_parts[0]
9694
key = "/".join(path_parts[1:])
9795
region = sagemaker_session.boto_region_name
98-
validate_region(region)
9996

10097
botocore_session = sagemaker_session.boto_session._session
10198
endpoint_resolver = botocore_session.get_component("endpoint_resolver")

sagemaker-core/src/sagemaker/core/spark/processing.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,7 @@ def _is_notebook_instance(self):
570570

571571
def _get_notebook_instance_domain(self):
572572
"""Get the instance's domain."""
573-
from sagemaker.core.region_validation import validate_region
574-
575573
region = self.sagemaker_session.boto_region_name
576-
validate_region(region)
577574
with open("/opt/ml/metadata/resource-metadata.json") as file:
578575
data = json.load(file)
579576
notebook_name = data["ResourceName"]

sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,9 +271,6 @@ def _construct_url(
271271
) -> str:
272272
"""Construct the URL for the telemetry request"""
273273

274-
from sagemaker.core.region_validation import validate_region
275-
276-
validate_region(region)
277274
base_url = (
278275
f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
279276
f"x-accountId={accountId}"

sagemaker-core/tests/unit/interactive_apps/test_profiler_app.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,16 @@ def test_detail_profiler_init_with_default_region():
120120
"""
121121
# happy case
122122
with patch(
123-
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
124-
) as region_mock:
125-
region_mock.return_value = TEST_REGION
123+
"sagemaker.core.interactive_apps.detail_profiler_app.Session"
124+
) as session_mock:
125+
session_mock.return_value.boto_region_name = TEST_REGION
126126
detail_profiler_app = DetailProfilerApp()
127127
assert detail_profiler_app.region == TEST_REGION
128128

129129
# no default region configured
130130
with patch(
131-
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
132-
) as region_mock:
133-
region_mock.side_effect = [ValueError()]
131+
"sagemaker.core.interactive_apps.detail_profiler_app.Session"
132+
) as session_mock:
133+
session_mock.side_effect = ValueError()
134134
with pytest.raises(ValueError):
135135
detail_profiler_app = DetailProfilerApp()

sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -824,16 +824,17 @@ def test_tb_init_with_default_region():
824824
"""
825825
# happy case
826826
with patch(
827-
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
828-
) as region_mock:
829-
region_mock.return_value = TEST_REGION
827+
"sagemaker.core.interactive_apps.base_interactive_app.Session"
828+
) as session_mock:
829+
session_mock.return_value.boto_region_name = TEST_REGION
830830
tb_app = TensorBoardApp()
831831
assert tb_app.region == TEST_REGION
832832

833833
# no default region configured
834834
with patch(
835-
"sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock
836-
) as region_mock:
837-
region_mock.side_effect = [ValueError()]
835+
"sagemaker.core.interactive_apps.base_interactive_app.Session"
836+
) as session_mock:
837+
session_mock.return_value.boto_region_name = PropertyMock(side_effect=ValueError())
838+
session_mock.side_effect = ValueError()
838839
with pytest.raises(ValueError):
839840
tb_app = TensorBoardApp()

sagemaker-core/tests/unit/test_region_validation.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,66 @@ def test_valid_endpoint(self, url):
159159
def test_invalid_endpoint(self, url):
160160
with pytest.raises(InvalidRegionError):
161161
validate_endpoint_url(url)
162+
163+
164+
class TestSessionRegionValidation:
165+
"""Ensure Session rejects invalid region at initialization."""
166+
167+
def test_session_rejects_malicious_region(self):
168+
from unittest.mock import patch, MagicMock
169+
170+
mock_boto_session = MagicMock()
171+
mock_boto_session.region_name = "[email protected]:443/#"
172+
173+
with pytest.raises(InvalidRegionError):
174+
from sagemaker.core.helper.session_helper import Session
175+
176+
Session(boto_session=mock_boto_session)
177+
178+
def test_session_accepts_valid_region(self):
179+
from unittest.mock import patch, MagicMock
180+
181+
mock_boto_session = MagicMock()
182+
mock_boto_session.region_name = "us-west-2"
183+
184+
with patch(
185+
"sagemaker.core.helper.session_helper.Session._initialize"
186+
) as mock_init:
187+
# Just verify validate_region doesn't raise for valid region
188+
validate_region("us-west-2")
189+
190+
191+
class TestBaseInteractiveAppRegionValidation:
192+
"""Ensure BaseInteractiveApp rejects invalid region at initialization."""
193+
194+
def test_rejects_malicious_region(self):
195+
from unittest.mock import patch
196+
197+
with pytest.raises(InvalidRegionError):
198+
from sagemaker.core.interactive_apps.tensorboard import TensorBoardApp
199+
200+
with patch("boto3.client"):
201+
TensorBoardApp(region="[email protected]:443/#")
202+
203+
def test_accepts_valid_region(self):
204+
from unittest.mock import patch, MagicMock
205+
206+
with patch("boto3.client") as mock_client, patch(
207+
"sagemaker.core.interactive_apps.base_interactive_app.BaseInteractiveApp._get_domain_and_user"
208+
):
209+
from sagemaker.core.interactive_apps.tensorboard import TensorBoardApp
210+
211+
app = TensorBoardApp(region="us-west-2")
212+
assert app.region == "us-west-2"
213+
214+
215+
class TestDetailProfilerAppRegionValidation:
216+
"""Ensure DetailProfilerApp rejects invalid region at initialization."""
217+
218+
def test_rejects_malicious_region(self):
219+
with pytest.raises(InvalidRegionError):
220+
from sagemaker.core.interactive_apps.detail_profiler_app import (
221+
DetailProfilerApp,
222+
)
223+
224+
DetailProfilerApp(region="[email protected]:443/#")

0 commit comments

Comments
 (0)