Skip to content

Commit 51321c3

Browse files
committed
chore: always use model channel for private hub models, add unit tests
1 parent ed67de5 commit 51321c3

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from sagemaker.jumpstart.constants import (
5555
JUMPSTART_DEFAULT_REGION_NAME,
5656
JUMPSTART_LOGGER,
57+
JUMPSTART_MODEL_HUB_NAME,
5758
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5859
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
5960
)
@@ -631,7 +632,13 @@ def _add_model_reference_arn_to_kwargs(
631632

632633
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
633634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
634-
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
635+
# hub_arn is by default None unless the user specifies the hub_name
636+
# If no hub_name is specified, it is assumed the public hub
637+
# Training platform enforces that private hub models must use model channel
638+
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
639+
if is_private_hub or _model_supports_training_model_uri(
640+
**get_model_info_default_kwargs(kwargs)
641+
):
635642
default_model_uri = model_uris.retrieve(
636643
model_scope=JumpStartScriptScope.TRAINING,
637644
instance_type=kwargs.instance_type,

tests/unit/sagemaker/jumpstart/factory/test_estimator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414
import pytest
1515
from unittest.mock import patch
16+
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
1617
from sagemaker.jumpstart.factory.estimator import (
1718
_add_model_uri_to_kwargs,
1819
get_model_info_default_kwargs,
@@ -119,3 +120,43 @@ def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_trainin
119120

120121
mock_supports_training.assert_called_once()
121122
assert result.model_uri is None
123+
124+
@patch(
125+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
126+
return_value=False,
127+
)
128+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
129+
def test_add_model_uri_to_kwargs_private_hub(
130+
self, mock_retrieve, mock_supports_training, mock_kwargs
131+
):
132+
"""Test when model is from a private hub."""
133+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
134+
mock_retrieve.return_value = default_uri
135+
mock_kwargs.hub_arn = "arn:aws:sagemaker:us-west-2:123456789012:hub/private-hub"
136+
137+
result = _add_model_uri_to_kwargs(mock_kwargs)
138+
139+
# Should not check if model supports training model URI for private hub
140+
mock_supports_training.assert_not_called()
141+
mock_retrieve.assert_called_once()
142+
assert result.model_uri == default_uri
143+
144+
@patch(
145+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
146+
return_value=False,
147+
)
148+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
149+
def test_add_model_uri_to_kwargs_public_hub(
150+
self, mock_retrieve, mock_supports_training, mock_kwargs
151+
):
152+
"""Test when model is from the public hub."""
153+
mock_kwargs.hub_arn = (
154+
f"arn:aws:sagemaker:us-west-2:123456789012:hub/{JUMPSTART_MODEL_HUB_NAME}"
155+
)
156+
157+
result = _add_model_uri_to_kwargs(mock_kwargs)
158+
159+
# Should check if model supports training model URI for public hub
160+
mock_supports_training.assert_called_once()
161+
mock_retrieve.assert_not_called()
162+
assert result.model_uri is None

0 commit comments

Comments
 (0)