8383 SM_CODE_CONTAINER_PATH ,
8484 SM_DRIVERS ,
8585 SM_DRIVERS_LOCAL_PATH ,
86+ SM_RECIPE ,
87+ SM_RECIPE_YAML ,
88+ SM_RECIPE_CONTAINER_PATH ,
8689 TRAIN_SCRIPT ,
8790 DEFAULT_CONTAINER_ENTRYPOINT ,
8891 DEFAULT_CONTAINER_ARGUMENTS ,
100103from sagemaker .core .telemetry .telemetry_logging import _telemetry_emitter
101104from sagemaker .core .telemetry .constants import Feature
102105from sagemaker .train import logger
103- from sagemaker .train .sm_recipes .utils import _get_args_from_recipe , _determine_device_type
106+ from sagemaker .train .sm_recipes .utils import (
107+ _get_args_from_recipe ,
108+ _determine_device_type ,
109+ _is_nova_recipe ,
110+ _is_llmft_recipe ,
111+ _load_base_recipe ,
112+ )
104113
105114from sagemaker .core .jumpstart .configs import JumpStartConfig
106115from sagemaker .core .jumpstart .document import get_hub_content_and_document
@@ -249,6 +258,8 @@ class ModelTrainer(BaseModel):
249258 _remote_debug_config : Optional [RemoteDebugConfig ] = PrivateAttr (default = None )
250259 _metric_definitions : Optional [List [MetricDefinition ]] = PrivateAttr (default = None )
251260
261+ _is_nova_recipe : Optional [bool ] = PrivateAttr (default = None )
262+ _is_llmft_recipe : Optional [bool ] = PrivateAttr (default = None )
252263 # Private Attributes for Recipes
253264 _temp_recipe_train_dir : Optional [TemporaryDirectory ] = PrivateAttr (default = None )
254265
@@ -573,6 +584,23 @@ def _create_training_job_args(
573584
574585 final_input_data_config = list (existing_channels .values ()) + new_channels
575586
587+ if self ._is_nova_recipe or self ._is_llmft_recipe :
588+ for input_data in final_input_data_config :
589+ if input_data .channel_name == SM_RECIPE :
590+ raise ValueError (
591+ "Cannot use reserved channel name 'recipe' as an input channel name "
592+ " for Nova or LLMFT Recipe"
593+ )
594+ recipe_file_path = os .path .join (self ._temp_recipe_train_dir .name , SM_RECIPE_YAML )
595+ recipe_channel = self .create_input_data_channel (
596+ channel_name = SM_RECIPE ,
597+ data_source = recipe_file_path ,
598+ key_prefix = input_data_key_prefix ,
599+ )
600+ final_input_data_config .append (recipe_channel )
601+ if self ._is_nova_recipe or self ._is_llmft_recipe :
602+ self .hyperparameters .update ({"sagemaker_recipe_local_path" : SM_RECIPE_CONTAINER_PATH })
603+
576604 if final_input_data_config :
577605 final_input_data_config = self ._get_input_data_config (
578606 final_input_data_config , input_data_key_prefix
@@ -1039,6 +1067,7 @@ def from_recipe(
10391067 checkpoint_config : Optional [shapes .CheckpointConfig ] = None ,
10401068 training_input_mode : Optional [str ] = "File" ,
10411069 environment : Optional [Dict [str , str ]] = None ,
1070+ hyperparameters : Optional [Union [Dict [str , Any ], str ]] = {},
10421071 tags : Optional [List [Tag ]] = None ,
10431072 sagemaker_session : Optional [Session ] = None ,
10441073 role : Optional [str ] = None ,
@@ -1136,12 +1165,20 @@ def from_recipe(
11361165 if compute .instance_type is None :
11371166 raise ValueError ("Must set ``instance_type`` in Compute when using training recipes." )
11381167 device_type = _determine_device_type (compute .instance_type )
1139- if device_type == "cpu" :
1168+ recipe = _load_base_recipe (
1169+ training_recipe = training_recipe , recipe_overrides = recipe_overrides
1170+ )
1171+ is_nova = _is_nova_recipe (recipe = recipe )
1172+ is_llmft = _is_llmft_recipe (recipe = recipe )
1173+ if device_type == "cpu" and not (is_nova or is_llmft ):
11401174 raise ValueError (
11411175 "Training recipes are not supported for CPU instances. "
11421176 "Please provide a GPU or Tranium instance type."
11431177 )
11441178
1179+ if training_image is None and (is_nova or is_llmft ):
1180+ raise ValueError ("training_image must be provided when using recipe for Nova or LLMFT" )
1181+
11451182 if training_image_config and training_image is None :
11461183 raise ValueError ("training_image must be provided when using training_image_config." )
11471184
@@ -1154,16 +1191,29 @@ def from_recipe(
11541191 # - distributed
11551192 # - compute
11561193 # - hyperparameters
1157- model_trainer_args , recipe_train_dir = _get_args_from_recipe (
1158- training_recipe = training_recipe ,
1194+ model_trainer_args , tmp_dir = _get_args_from_recipe (
1195+ training_recipe = recipe ,
11591196 recipe_overrides = recipe_overrides ,
11601197 requirements = requirements ,
11611198 compute = compute ,
11621199 region_name = sagemaker_session .boto_region_name ,
1200+ role = role ,
11631201 )
11641202 if training_image is not None :
11651203 model_trainer_args ["training_image" ] = training_image
11661204
1205+ if hyperparameters and not is_nova :
1206+ logger .warning (
1207+ "Hyperparameters are not supported for general and LLMFT training recipes. "
1208+ + "Ignoring hyperparameters input."
1209+ )
1210+ if is_nova :
1211+ if hyperparameters and isinstance (hyperparameters , str ):
1212+ hyperparameters = cls ._validate_and_load_hyperparameters_file (hyperparameters )
1213+ model_trainer_args ["hyperparameters" ].update (hyperparameters )
1214+ elif hyperparameters and isinstance (hyperparameters , dict ):
1215+ model_trainer_args ["hyperparameters" ].update (hyperparameters )
1216+
11671217 model_trainer = cls (
11681218 sagemaker_session = sagemaker_session ,
11691219 role = role ,
@@ -1180,7 +1230,9 @@ def from_recipe(
11801230 ** model_trainer_args ,
11811231 )
11821232
1183- model_trainer ._temp_recipe_train_dir = recipe_train_dir
1233+ model_trainer ._is_nova_recipe = is_nova
1234+ model_trainer ._is_llmft_recipe = is_llmft
1235+ model_trainer ._temp_recipe_train_dir = tmp_dir
11841236 return model_trainer
11851237
11861238 @classmethod
0 commit comments