Skip to content

Commit 66ea090

Browse files
rsareddy0329 and Roja Reddy Sareddy authored
Nova training support (#5489)
* Nova training support
* Nova, LLMFT training support

---------

Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent a8dd212 commit 66ea090

File tree

5 files changed

+605
-15
lines changed

5 files changed

+605
-15
lines changed

sagemaker-train/src/sagemaker/train/constants.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,3 +62,7 @@
6262
"mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
6363
"amazon.nova-pro-v1:0": ["us-east-1"]
6464
}
65+
66+
SM_RECIPE = "recipe"
67+
SM_RECIPE_YAML = "recipe.yaml"
68+
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"

sagemaker-train/src/sagemaker/train/model_trainer.py

Lines changed: 57 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,9 @@
8383
SM_CODE_CONTAINER_PATH,
8484
SM_DRIVERS,
8585
SM_DRIVERS_LOCAL_PATH,
86+
SM_RECIPE,
87+
SM_RECIPE_YAML,
88+
SM_RECIPE_CONTAINER_PATH,
8689
TRAIN_SCRIPT,
8790
DEFAULT_CONTAINER_ENTRYPOINT,
8891
DEFAULT_CONTAINER_ARGUMENTS,
@@ -100,7 +103,13 @@
100103
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
101104
from sagemaker.core.telemetry.constants import Feature
102105
from sagemaker.train import logger
103-
from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
106+
from sagemaker.train.sm_recipes.utils import (
107+
_get_args_from_recipe,
108+
_determine_device_type,
109+
_is_nova_recipe,
110+
_is_llmft_recipe,
111+
_load_base_recipe,
112+
)
104113

105114
from sagemaker.core.jumpstart.configs import JumpStartConfig
106115
from sagemaker.core.jumpstart.document import get_hub_content_and_document
@@ -249,6 +258,8 @@ class ModelTrainer(BaseModel):
249258
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
250259
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
251260

261+
_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
262+
_is_llmft_recipe: Optional[bool] = PrivateAttr(default=None)
252263
# Private Attributes for Recipes
253264
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
254265

@@ -573,6 +584,23 @@ def _create_training_job_args(
573584

574585
final_input_data_config = list(existing_channels.values()) + new_channels
575586

587+
if self._is_nova_recipe or self._is_llmft_recipe:
588+
for input_data in final_input_data_config:
589+
if input_data.channel_name == SM_RECIPE:
590+
raise ValueError(
591+
"Cannot use reserved channel name 'recipe' as an input channel name "
592+
" for Nova or LLMFT Recipe"
593+
)
594+
recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
595+
recipe_channel = self.create_input_data_channel(
596+
channel_name=SM_RECIPE,
597+
data_source=recipe_file_path,
598+
key_prefix=input_data_key_prefix,
599+
)
600+
final_input_data_config.append(recipe_channel)
601+
if self._is_nova_recipe or self._is_llmft_recipe:
602+
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
603+
576604
if final_input_data_config:
577605
final_input_data_config = self._get_input_data_config(
578606
final_input_data_config, input_data_key_prefix
@@ -1039,6 +1067,7 @@ def from_recipe(
10391067
checkpoint_config: Optional[shapes.CheckpointConfig] = None,
10401068
training_input_mode: Optional[str] = "File",
10411069
environment: Optional[Dict[str, str]] = None,
1070+
hyperparameters: Optional[Union[Dict[str, Any], str]] = {},
10421071
tags: Optional[List[Tag]] = None,
10431072
sagemaker_session: Optional[Session] = None,
10441073
role: Optional[str] = None,
@@ -1136,12 +1165,20 @@ def from_recipe(
11361165
if compute.instance_type is None:
11371166
raise ValueError("Must set ``instance_type`` in Compute when using training recipes.")
11381167
device_type = _determine_device_type(compute.instance_type)
1139-
if device_type == "cpu":
1168+
recipe = _load_base_recipe(
1169+
training_recipe=training_recipe, recipe_overrides=recipe_overrides
1170+
)
1171+
is_nova = _is_nova_recipe(recipe=recipe)
1172+
is_llmft = _is_llmft_recipe(recipe=recipe)
1173+
if device_type == "cpu" and not (is_nova or is_llmft):
11401174
raise ValueError(
11411175
"Training recipes are not supported for CPU instances. "
11421176
"Please provide a GPU or Tranium instance type."
11431177
)
11441178

1179+
if training_image is None and (is_nova or is_llmft):
1180+
raise ValueError("training_image must be provided when using recipe for Nova or LLMFT")
1181+
11451182
if training_image_config and training_image is None:
11461183
raise ValueError("training_image must be provided when using training_image_config.")
11471184

@@ -1154,16 +1191,29 @@ def from_recipe(
11541191
# - distributed
11551192
# - compute
11561193
# - hyperparameters
1157-
model_trainer_args, recipe_train_dir = _get_args_from_recipe(
1158-
training_recipe=training_recipe,
1194+
model_trainer_args, tmp_dir = _get_args_from_recipe(
1195+
training_recipe=recipe,
11591196
recipe_overrides=recipe_overrides,
11601197
requirements=requirements,
11611198
compute=compute,
11621199
region_name=sagemaker_session.boto_region_name,
1200+
role=role,
11631201
)
11641202
if training_image is not None:
11651203
model_trainer_args["training_image"] = training_image
11661204

1205+
if hyperparameters and not is_nova:
1206+
logger.warning(
1207+
"Hyperparameters are not supported for general and LLMFT training recipes. "
1208+
+ "Ignoring hyperparameters input."
1209+
)
1210+
if is_nova:
1211+
if hyperparameters and isinstance(hyperparameters, str):
1212+
hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
1213+
model_trainer_args["hyperparameters"].update(hyperparameters)
1214+
elif hyperparameters and isinstance(hyperparameters, dict):
1215+
model_trainer_args["hyperparameters"].update(hyperparameters)
1216+
11671217
model_trainer = cls(
11681218
sagemaker_session=sagemaker_session,
11691219
role=role,
@@ -1180,7 +1230,9 @@ def from_recipe(
11801230
**model_trainer_args,
11811231
)
11821232

1183-
model_trainer._temp_recipe_train_dir = recipe_train_dir
1233+
model_trainer._is_nova_recipe = is_nova
1234+
model_trainer._is_llmft_recipe = is_llmft
1235+
model_trainer._temp_recipe_train_dir = tmp_dir
11841236
return model_trainer
11851237

11861238
@classmethod

0 commit comments

Comments
 (0)