File tree 3 files changed +25
-15
lines changed
3 files changed +25
-15
lines changed Original file line number Diff line number Diff line change @@ -372,10 +372,18 @@ def _get_json_file(
372
372
object and None when reading from the local file system.
373
373
"""
374
374
if self ._is_local_metadata_mode ():
375
- file_content , etag = self ._get_json_file_from_local_override (key , filetype ), None
376
- else :
377
- file_content , etag = self ._get_json_file_and_etag_from_s3 (key )
378
- return file_content , etag
375
+ if filetype in {
376
+ JumpStartS3FileType .OPEN_WEIGHT_MANIFEST ,
377
+ JumpStartS3FileType .OPEN_WEIGHT_SPECS ,
378
+ }:
379
+ return self ._get_json_file_from_local_override (key , filetype ), None
380
+ else :
381
+ JUMPSTART_LOGGER .warning (
382
+ "Local metadata mode is enabled, but the file type %s is not supported "
383
+ "for local override. Falling back to s3." ,
384
+ filetype ,
385
+ )
386
+ return self ._get_json_file_and_etag_from_s3 (key )
379
387
380
388
def _get_json_md5_hash (self , key : str ):
381
389
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.
Original file line number Diff line number Diff line change @@ -632,13 +632,7 @@ def _add_model_reference_arn_to_kwargs(
632
632
633
633
def _add_model_uri_to_kwargs (kwargs : JumpStartEstimatorInitKwargs ) -> JumpStartEstimatorInitKwargs :
634
634
"""Sets model uri in kwargs based on default or override, returns full kwargs."""
635
- # hub_arn is by default None unless the user specifies the hub_name
636
- # If no hub_name is specified, it is assumed the public hub
637
- is_private_hub = JUMPSTART_MODEL_HUB_NAME not in kwargs .hub_arn if kwargs .hub_arn else False
638
- if (
639
- _model_supports_training_model_uri (** get_model_info_default_kwargs (kwargs ))
640
- or is_private_hub
641
- ):
635
+ if _model_supports_training_model_uri (** get_model_info_default_kwargs (kwargs )):
642
636
default_model_uri = model_uris .retrieve (
643
637
model_scope = JumpStartScriptScope .TRAINING ,
644
638
instance_type = kwargs .instance_type ,
Original file line number Diff line number Diff line change @@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:
1940
1940
1941
1941
def use_training_model_artifact (self ) -> bool :
1942
1942
"""Returns True if the model should use a model uri when kicking off training job."""
1943
- # gated model never use training model artifact
1944
- if self .gated_bucket :
1943
+ # old models with this environment variable present don't use model channel
1944
+ if any (
1945
+ self .training_instance_type_variants .get_instance_specific_gated_model_key_env_var_value (
1946
+ instance_type
1947
+ )
1948
+ for instance_type in self .supported_training_instance_types
1949
+ ):
1950
+ return False
1951
+
1952
+ # even older models with training model package artifact uris present also don't use model channel
1953
+ if len (self .training_model_package_artifact_uris or {}) > 0 :
1945
1954
return False
1946
1955
1947
- # otherwise, return true is a training model package is not set
1948
- return len (self .training_model_package_artifact_uris or {}) == 0
1956
+ return getattr (self , "training_artifact_key" , None ) is not None
1949
1957
1950
1958
def is_gated_model (self ) -> bool :
1951
1959
"""Returns True if the model has a EULA key or the model bucket is gated."""
You can’t perform that action at this time.
0 commit comments