Skip to content

fix: include model channel for gated uncompressed models #5181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,18 @@ def _get_json_file(
object and None when reading from the local file system.
"""
if self._is_local_metadata_mode():
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
else:
file_content, etag = self._get_json_file_and_etag_from_s3(key)
return file_content, etag
if filetype in {
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
}:
return self._get_json_file_from_local_override(key, filetype), None
else:
JUMPSTART_LOGGER.warning(
"Local metadata mode is enabled, but the file type %s is not supported "
"for local override. Falling back to s3.",
filetype,
)
return self._get_json_file_and_etag_from_s3(key)

def _get_json_md5_hash(self, key: str):
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@
from sagemaker.jumpstart.constants import (
JUMPSTART_DEFAULT_REGION_NAME,
JUMPSTART_LOGGER,
JUMPSTART_MODEL_HUB_NAME,
TRAINING_ENTRY_POINT_SCRIPT_NAME,
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY,
JUMPSTART_MODEL_HUB_NAME,
)
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
from sagemaker.jumpstart.factory import model
Expand Down Expand Up @@ -634,10 +634,10 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
# hub_arn is by default None unless the user specifies the hub_name
# If no hub_name is specified, it is assumed the public hub
# Training platform enforces that private hub models must use model channel
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
if (
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
or is_private_hub
if is_private_hub or _model_supports_training_model_uri(
**get_model_info_default_kwargs(kwargs)
):
default_model_uri = model_uris.retrieve(
model_scope=JumpStartScriptScope.TRAINING,
Expand Down
16 changes: 12 additions & 4 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:

def use_training_model_artifact(self) -> bool:
"""Returns True if the model should use a model uri when kicking off training job."""
# gated model never use training model artifact
if self.gated_bucket:
# old models with this environment variable present don't use model channel
if any(
self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value(
instance_type
)
for instance_type in self.supported_training_instance_types
):
return False

# even older models with training model package artifact uris present also don't use model channel
if len(self.training_model_package_artifact_uris or {}) > 0:
return False

# otherwise, return true is a training model package is not set
return len(self.training_model_package_artifact_uris or {}) == 0
return getattr(self, "training_artifact_key", None) is not None

def is_gated_model(self) -> bool:
"""Returns True if the model has a EULA key or the model bucket is gated."""
Expand Down
Empty file.
162 changes: 162 additions & 0 deletions tests/unit/sagemaker/jumpstart/factory/test_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surprised we didn't have this before, thanks for adding it!

#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import pytest
from unittest.mock import patch
from sagemaker.jumpstart.constants import JUMPSTART_MODEL_HUB_NAME
from sagemaker.jumpstart.factory.estimator import (
_add_model_uri_to_kwargs,
get_model_info_default_kwargs,
)
from sagemaker.jumpstart.types import JumpStartEstimatorInitKwargs
from sagemaker.jumpstart.enums import JumpStartScriptScope


class TestAddModelUriToKwargs:
@pytest.fixture
def mock_kwargs(self):
return JumpStartEstimatorInitKwargs(
model_id="test-model",
model_version="1.0.0",
instance_type="ml.m5.large",
model_uri=None,
)

@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
return_value=True,
)
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
def test_add_model_uri_to_kwargs_default_uri(
self, mock_retrieve, mock_supports_training, mock_kwargs
):
"""Test adding default model URI when none is provided."""
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
mock_retrieve.return_value = default_uri

result = _add_model_uri_to_kwargs(mock_kwargs)

mock_supports_training.assert_called_once()
mock_retrieve.assert_called_once_with(
model_scope=JumpStartScriptScope.TRAINING,
instance_type=mock_kwargs.instance_type,
**get_model_info_default_kwargs(mock_kwargs),
)
assert result.model_uri == default_uri

@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
return_value=True,
)
@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_incremental_training",
return_value=True,
)
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
def test_add_model_uri_to_kwargs_custom_uri_with_incremental(
self, mock_retrieve, mock_supports_incremental, mock_supports_training, mock_kwargs
):
"""Test using custom model URI with incremental training support."""
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
custom_uri = "s3://custom-bucket/my-model"
mock_retrieve.return_value = default_uri
mock_kwargs.model_uri = custom_uri

result = _add_model_uri_to_kwargs(mock_kwargs)

mock_supports_training.assert_called_once()
mock_supports_incremental.assert_called_once()
assert result.model_uri == custom_uri

@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
return_value=True,
)
@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_incremental_training",
return_value=False,
)
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
@patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
def test_add_model_uri_to_kwargs_custom_uri_without_incremental(
self,
mock_warning,
mock_retrieve,
mock_supports_incremental,
mock_supports_training,
mock_kwargs,
):
"""Test using custom model URI without incremental training support logs warning."""
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
custom_uri = "s3://custom-bucket/my-model"
mock_retrieve.return_value = default_uri
mock_kwargs.model_uri = custom_uri

result = _add_model_uri_to_kwargs(mock_kwargs)

mock_supports_training.assert_called_once()
mock_supports_incremental.assert_called_once()
mock_warning.assert_called_once()
assert "does not support incremental training" in mock_warning.call_args[0][0]
assert result.model_uri == custom_uri

@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
return_value=False,
)
def test_add_model_uri_to_kwargs_no_training_support(self, mock_supports_training, mock_kwargs):
"""Test when model doesn't support training model URI."""
result = _add_model_uri_to_kwargs(mock_kwargs)

mock_supports_training.assert_called_once()
assert result.model_uri is None

@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
return_value=False,
)
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
def test_add_model_uri_to_kwargs_private_hub(
self, mock_retrieve, mock_supports_training, mock_kwargs
):
"""Test when model is from a private hub."""
default_uri = "s3://jumpstart-models/training/test-model/1.0.0"
mock_retrieve.return_value = default_uri
mock_kwargs.hub_arn = "arn:aws:sagemaker:us-west-2:123456789012:hub/private-hub"

result = _add_model_uri_to_kwargs(mock_kwargs)

# Should not check if model supports training model URI for private hub
mock_supports_training.assert_not_called()
mock_retrieve.assert_called_once()
assert result.model_uri == default_uri

@patch(
"sagemaker.jumpstart.factory.estimator._model_supports_training_model_uri",
return_value=False,
)
@patch("sagemaker.jumpstart.factory.estimator.model_uris.retrieve")
def test_add_model_uri_to_kwargs_public_hub(
self, mock_retrieve, mock_supports_training, mock_kwargs
):
"""Test when model is from the public hub."""
mock_kwargs.hub_arn = (
f"arn:aws:sagemaker:us-west-2:123456789012:hub/{JUMPSTART_MODEL_HUB_NAME}"
)

result = _add_model_uri_to_kwargs(mock_kwargs)

# Should check if model supports training model URI for public hub
mock_supports_training.assert_called_once()
mock_retrieve.assert_not_called()
assert result.model_uri is None
75 changes: 75 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,3 +1288,78 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_func
assert_key = JumpStartVersionedModelId("test-model", "abc")

assert result == assert_key


@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
@patch(
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
)
def test_get_json_file_from_s3():
"""Test _get_json_file retrieves from S3 in normal mode."""
cache = JumpStartModelsCache()
test_key = "test/file/path.json"
test_json_data = {"key": "value"}
test_etag = "test-etag-123"

with patch.object(
JumpStartModelsCache,
"_get_json_file_and_etag_from_s3",
return_value=(test_json_data, test_etag),
) as mock_s3_get:
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)

mock_s3_get.assert_called_once_with(test_key)
assert result == test_json_data
assert etag == test_etag


@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
@patch(
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
)
def test_get_json_file_from_local_supported_type():
"""Test _get_json_file retrieves from local override for supported file types."""
cache = JumpStartModelsCache()
test_key = "test/file/path.json"
test_json_data = {"key": "value"}

with (
patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True),
patch.object(
JumpStartModelsCache, "_get_json_file_from_local_override", return_value=test_json_data
) as mock_local_get,
):
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)

mock_local_get.assert_called_once_with(test_key, JumpStartS3FileType.OPEN_WEIGHT_MANIFEST)
assert result == test_json_data
assert etag is None


@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region")
@patch(
"sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket"
)
def test_get_json_file_local_mode_unsupported_type():
"""Test _get_json_file falls back to S3 for unsupported file types in local mode."""
cache = JumpStartModelsCache()
test_key = "test/file/path.json"
test_json_data = {"key": "value"}
test_etag = "test-etag-123"

with (
patch.object(JumpStartModelsCache, "_is_local_metadata_mode", return_value=True),
patch.object(
JumpStartModelsCache,
"_get_json_file_and_etag_from_s3",
return_value=(test_json_data, test_etag),
) as mock_s3_get,
patch("sagemaker.jumpstart.cache.JUMPSTART_LOGGER.warning") as mock_warning,
):
result, etag = cache._get_json_file(test_key, JumpStartS3FileType.PROPRIETARY_MANIFEST)

mock_s3_get.assert_called_once_with(test_key)
mock_warning.assert_called_once()
assert "not supported for local override" in mock_warning.call_args[0][0]
assert result == test_json_data
assert etag == test_etag
71 changes: 63 additions & 8 deletions tests/unit/sagemaker/jumpstart/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
INIT_KWARGS,
)

from unittest.mock import Mock

INSTANCE_TYPE_VARIANT = JumpStartInstanceTypeVariants(
{
"regional_aliases": {
Expand Down Expand Up @@ -329,14 +331,67 @@ def test_jumpstart_model_header():
assert header1 == header3


def test_use_training_model_artifact():
specs1 = JumpStartModelSpecs(BASE_SPEC)
assert specs1.use_training_model_artifact()
specs1.gated_bucket = True
assert not specs1.use_training_model_artifact()
specs1.gated_bucket = False
specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"}
assert not specs1.use_training_model_artifact()
class TestUseTrainingModelArtifact:
@pytest.fixture
def mock_specs(self):
specs = Mock(spec=JumpStartModelSpecs)
specs.training_instance_type_variants = Mock()
specs.supported_training_instance_types = ["ml.p3.2xlarge", "ml.g4dn.xlarge"]
specs.training_model_package_artifact_uris = {}
specs.training_artifact_key = None
return specs

def test_use_training_model_artifact_with_env_var(self, mock_specs):
"""Test when instance type variants have env var values."""
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.side_effect = [
"some-value",
None,
]

result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)

assert result is False
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.assert_any_call(
"ml.p3.2xlarge"
)

def test_use_training_model_artifact_with_package_uris(self, mock_specs):
"""Test when model has training package artifact URIs."""
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
None
)
mock_specs.training_model_package_artifact_uris = {
"ml.p3.2xlarge": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/"
"llama2-13b-e155a2e0347b323fb882f1875851c5d3"
}

result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)

assert result is False

def test_use_training_model_artifact_with_artifact_key(self, mock_specs):
"""Test when model has training artifact key."""
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
None
)
mock_specs.training_model_package_artifact_uris = {}
mock_specs.training_artifact_key = "some-key"

result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)

assert result is True

def test_use_training_model_artifact_without_artifact_key(self, mock_specs):
"""Test when model has no training artifact key."""
mock_specs.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value.return_value = (
None
)
mock_specs.training_model_package_artifact_uris = {}
mock_specs.training_artifact_key = None

result = JumpStartModelSpecs.use_training_model_artifact(mock_specs)

assert result is False


def test_jumpstart_model_specs():
Expand Down