Skip to content

Commit 7531524

Browse files
committed
fix: include model channel for gated uncompressed models
1 parent ddc54d2 commit 7531524

File tree

3 files changed

+25
-15
lines changed

3 files changed

+25
-15
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,18 @@ def _get_json_file(
372372
object and None when reading from the local file system.
373373
"""
374374
if self._is_local_metadata_mode():
375-
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
376-
else:
377-
file_content, etag = self._get_json_file_and_etag_from_s3(key)
378-
return file_content, etag
375+
if filetype in {
376+
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
377+
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
378+
}:
379+
return self._get_json_file_from_local_override(key, filetype), None
380+
else:
381+
JUMPSTART_LOGGER.warning(
382+
"Local metadata mode is enabled, but the file type %s is not supported "
383+
"for local override. Falling back to s3.",
384+
filetype,
385+
)
386+
return self._get_json_file_and_etag_from_s3(key)
379387

380388
def _get_json_md5_hash(self, key: str):
381389
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -632,13 +632,7 @@ def _add_model_reference_arn_to_kwargs(
632632

633633
def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartEstimatorInitKwargs:
634634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
635-
# hub_arn is by default None unless the user specifies the hub_name
636-
# If no hub_name is specified, the public hub is assumed
637-
is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs.hub_arn if kwargs.hub_arn else False
638-
if (
639-
_model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs))
640-
or is_private_hub
641-
):
635+
if _model_supports_training_model_uri(**get_model_info_default_kwargs(kwargs)):
642636
default_model_uri = model_uris.retrieve(
643637
model_scope=JumpStartScriptScope.TRAINING,
644638
instance_type=kwargs.instance_type,

src/sagemaker/jumpstart/types.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:
19401940

19411941
def use_training_model_artifact(self) -> bool:
19421942
"""Returns True if the model should use a model uri when kicking off training job."""
1943-
# gated model never use training model artifact
1944-
if self.gated_bucket:
1943+
# old models with this environment variable present don't use model channel
1944+
if any(
1945+
self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value(
1946+
instance_type
1947+
)
1948+
for instance_type in self.supported_training_instance_types
1949+
):
1950+
return False
1951+
1952+
# even older models with training model package artifact uris present also don't use model channel
1953+
if len(self.training_model_package_artifact_uris or {}) > 0:
19451954
return False
19461955

1947-
# otherwise, return true if a training model package is not set
1948-
return len(self.training_model_package_artifact_uris or {}) == 0
1956+
return getattr(self, "training_artifact_key", None) is not None
19491957

19501958
def is_gated_model(self) -> bool:
19511959
"""Returns True if the model has a EULA key or the model bucket is gated."""

0 commit comments

Comments
 (0)