Skip to content

Commit ade4453

Browse files
committed
chore: add unit tests
1 parent 7531524 commit ade4453

File tree

4 files changed

+245
-9
lines changed

4 files changed

+245
-9
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
JUMPSTART_LOGGER,
5757
TRAINING_ENTRY_POINT_SCRIPT_NAME,
5858
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
59-
JUMPSTART_MODEL_HUB_NAME,
6059
)
6160
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
6261
from sagemaker.jumpstart.factory import model
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import pytest
2+
from unittest.mock import patch, Mock
3+
from sagemaker.jumpstart.factory.estimator import (
4+
_add_model_uri_to_kwargs,
5+
get_model_info_default_kwargs,
6+
)
7+
from sagemaker.jumpstart.types import JumpStartEstimatorInitKwargs
8+
from sagemaker.jumpstart.enums import JumpStartScriptScope
9+
10+
11+
class TestAddModelUriToKwargs:
12+
@pytest.fixture
13+
def mock_kwargs(self):
14+
return JumpStartEstimatorInitKwargs(
15+
model_id="test-model",
16+
model_version="1.0.0",
17+
instance_type="ml.m5.large",
18+
model_uri=None,
19+
)
20+
21+
@patch(
22+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
23+
return_value=True,
24+
)
25+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
26+
def test_add_model_uri_to_kwargs_default_uri(
27+
self, mock_retrieve, mock_supports_training, mock_kwargs
28+
):
29+
"""Test adding default model URI when none is provided."""
30+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
31+
mock_retrieve.return_value = default_uri
32+
33+
result = _add_model_uri_to_kwargs(mock_kwargs)
34+
35+
mock_supports_training.assert_called_once()
36+
mock_retrieve.assert_called_once_with(
37+
model_scope=JumpStartScriptScope.TRAINING,
38+
instance_type=mock_kwargs.instance_type,
39+
**get_model_info_default_kwargs(mock_kwargs),
40+
)
41+
assert result.model_uri == default_uri
42+
43+
@patch(
44+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
45+
return_value=True,
46+
)
47+
@patch(
48+
"sagemaker.jumpstart.factory.estimator._model_supports_incremental_training",
49+
return_value=True,
50+
)
51+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
52+
def test_add_model_uri_to_kwargs_custom_uri_with_incremental(
53+
self, mock_retrieve, mock_supports_incremental, mock_supports_training, mock_kwargs
54+
):
55+
"""Test using custom model URI with incremental training support."""
56+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
57+
custom_uri = "s3://custom-bucket/my-model"
58+
mock_retrieve.return_value = default_uri
59+
mock_kwargs.model_uri = custom_uri
60+
61+
result = _add_model_uri_to_kwargs(mock_kwargs)
62+
63+
mock_supports_training.assert_called_once()
64+
mock_supports_incremental.assert_called_once()
65+
assert result.model_uri == custom_uri
66+
67+
@patch(
68+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
69+
return_value=True,
70+
)
71+
@patch(
72+
"sagemaker.jumpstart.factory.estimator._model_supports_incremental_training",
73+
return_value=False,
74+
)
75+
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
76+
@patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
77+
def test_add_model_uri_to_kwargs_custom_uri_without_incremental(
78+
self,
79+
mock_warning,
80+
mock_retrieve,
81+
mock_supports_incremental,
82+
mock_supports_training,
83+
mock_kwargs,
84+
):
85+
"""Test using custom model URI without incremental training support logs warning."""
86+
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
87+
custom_uri = "s3://custom-bucket/my-model"
88+
mock_retrieve.return_value = default_uri
89+
mock_kwargs.model_uri = custom_uri
90+
91+
result = _add_model_uri_to_kwargs(mock_kwargs)
92+
93+
mock_supports_training.assert_called_once()
94+
mock_supports_incremental.assert_called_once()
95+
mock_warning.assert_called_once()
96+
assert "does not support incremental training" in mock_warning.call_args[0][0]
97+
assert result.model_uri == custom_uri
98+
99+
@patch(
100+
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
101+
return_value=False,
102+
)
103+
def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_training, mock_kwargs):
104+
"""Test when model doesn't support training model URI."""
105+
result = _add_model_uri_to_kwargs(mock_kwargs)
106+
107+
mock_supports_training.assert_called_once()
108+
assert result.model_uri is None

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,3 +1288,78 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_func
12881288
assert_key = JumpStartVersionedModelId("test-model", "abc")
12891289

12901290
assert result == assert_key
1291+
1292+
1293+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
1294+
@patch(
1295+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
1296+
)
1297+
def test_get_json_file_from_s3():
1298+
"""Test _get_json_file retrieves from S3 in normal mode."""
1299+
cache = JumpStartModelsCache()
1300+
test_key = "test/file/path.json"
1301+
test_json_data = {"key": "value"}
1302+
test_etag = "test-etag-123"
1303+
1304+
with patch.object(
1305+
JumpStartModelsCache,
1306+
"_get_json_file_and_etag_from_s3",
1307+
return_value=(test_json_data, test_etag),
1308+
) as mock_s3_get:
1309+
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
1310+
1311+
mock_s3_get.assert_called_once_with(test_key)
1312+
assert result == test_json_data
1313+
assert etag == test_etag
1314+
1315+
1316+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
1317+
@patch(
1318+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
1319+
)
1320+
def test_get_json_file_from_local_supported_type():
1321+
"""Test _get_json_file retrieves from local override for supported file types."""
1322+
cache = JumpStartModelsCache()
1323+
test_key = "test/file/path.json"
1324+
test_json_data = {"key": "value"}
1325+
1326+
with (
1327+
patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True),
1328+
patch.object(
1329+
JumpStartModelsCache, "_get_json_file_from_local_override", return_value=test_json_data
1330+
) as mock_local_get,
1331+
):
1332+
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
1333+
1334+
mock_local_get.assert_called_once_with(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
1335+
assert result == test_json_data
1336+
assert etag is None
1337+
1338+
1339+
@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
1340+
@patch(
1341+
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
1342+
)
1343+
def test_get_json_file_local_mode_unsupported_type():
1344+
"""Test _get_json_file falls back to S3 for unsupported file types in local mode."""
1345+
cache = JumpStartModelsCache()
1346+
test_key = "test/file/path.json"
1347+
test_json_data = {"key": "value"}
1348+
test_etag = "test-etag-123"
1349+
1350+
with (
1351+
patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True),
1352+
patch.object(
1353+
JumpStartModelsCache,
1354+
"_get_json_file_and_etag_from_s3",
1355+
return_value=(test_json_data, test_etag),
1356+
) as mock_s3_get,
1357+
patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") as mock_warning,
1358+
):
1359+
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.PROPRIETARY_MANIFEST)
1360+
1361+
mock_s3_get.assert_called_once_with(test_key)
1362+
mock_warning.assert_called_once()
1363+
assert "not supported for local override" in mock_warning.call_args[0][0]
1364+
assert result == test_json_data
1365+
assert etag == test_etag

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
INIT_KWARGS,
4040
)
4141

42+
from unittest.mock import Mock
43+
4244
INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants(
4345
{
4446
"regional_aliases": {
@@ -329,14 +331,66 @@ def test_jumpstart_model_header():
329331
assert header1 == header3
330332

331333

332-
def test_use_training_model_artifact():
333-
specs1 = JumpStartModelSpecs(BASE_SPEC)
334-
assert specs1.use_training_model_artifact()
335-
specs1.gated_bucket = True
336-
assert not specs1.use_training_model_artifact()
337-
specs1.gated_bucket = False
338-
specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"}
339-
assert not specs1.use_training_model_artifact()
334+
class TestUseTrainingModelArtifact:
335+
@pytest.fixture
336+
def mock_specs(self):
337+
specs = Mock(spec=JumpStartModelSpecs)
338+
specs.training_instance_type_variants = Mock()
339+
specs.supported_training_instance_types = ["ml.p3.2xlarge", "ml.g4dn.xlarge"]
340+
specs.training_model_package_artifact_uris = {}
341+
specs.training_artifact_key = None
342+
return specs
343+
344+
def test_use_training_model_artifact_with_env_var(self, mock_specs):
345+
"""Test when instance type variants have env var values."""
346+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.side_effect = [
347+
"some-value",
348+
None,
349+
]
350+
351+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
352+
353+
assert result is False
354+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.assert_any_call(
355+
"ml.p3.2xlarge"
356+
)
357+
358+
def test_use_training_model_artifact_with_package_uris(self, mock_specs):
359+
"""Test when model has training package artifact URIs."""
360+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
361+
None
362+
)
363+
mock_specs.training_model_package_artifact_uris = {
364+
"ml.p3.2xlarge": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/llama2-13b-e155a2e0347b323fb882f1875851c5d3"
365+
}
366+
367+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
368+
369+
assert result is False
370+
371+
def test_use_training_model_artifact_with_artifact_key(self, mock_specs):
372+
"""Test when model has training artifact key."""
373+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
374+
None
375+
)
376+
mock_specs.training_model_package_artifact_uris = {}
377+
mock_specs.training_artifact_key = "some-key"
378+
379+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
380+
381+
assert result is True
382+
383+
def test_use_training_model_artifact_without_artifact_key(self, mock_specs):
384+
"""Test when model has no training artifact key."""
385+
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
386+
None
387+
)
388+
mock_specs.training_model_package_artifact_uris = {}
389+
mock_specs.training_artifact_key = None
390+
391+
result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)
392+
393+
assert result is False
340394

341395

342396
def test_jumpstart_model_specs():

0 commit comments

Comments
 (0)