
Commit 4b41ce8

Merge branch 'master' into add-inference-pipeline-example
2 parents: 3873bb2 + ad190b9

File tree: 6 files changed, +94, -62 lines


sagemaker-core/src/sagemaker/core/helper/pipeline_variable.py

Lines changed: 4 additions & 0 deletions
@@ -80,3 +80,7 @@ def __get_pydantic_core_schema__(cls, source_type, handler):
 
 # This is a type that could be either string or pipeline variable
 StrPipeVar = Union[str, PipelineVariable]
+# This is a type that could be either integer or pipeline variable
+IntPipeVar = Union[int, PipelineVariable]
+# This is a type that could be either boolean or pipeline variable
+BoolPipeVar = Union[bool, PipelineVariable]
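
The new aliases follow the existing StrPipeVar pattern: a field typed with one of them accepts either a plain literal or a PipelineVariable that is resolved when the pipeline runs. A minimal sketch of the idea (the ExampleConfig model and the PipelineVariable stand-in below are illustrative only, not part of this commit):

from typing import Optional, Union
from pydantic import BaseModel, ConfigDict


class PipelineVariable:
    # Illustrative stand-in for sagemaker.core.helper.pipeline_variable.PipelineVariable.
    pass


StrPipeVar = Union[str, PipelineVariable]
IntPipeVar = Union[int, PipelineVariable]
BoolPipeVar = Union[bool, PipelineVariable]


class ExampleConfig(BaseModel):
    # Hypothetical config model: literals and pipeline variables share one field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    instance_type: Optional[StrPipeVar] = None
    instance_count: Optional[IntPipeVar] = None
    enable_spot: Optional[BoolPipeVar] = None


# A literal int and a PipelineVariable instance both validate against IntPipeVar.
literal_cfg = ExampleConfig(instance_type="ml.m5.xlarge", instance_count=2, enable_spot=True)
param_cfg = ExampleConfig(instance_count=PipelineVariable())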

sagemaker-core/src/sagemaker/core/shapes/shapes.py

Lines changed: 4 additions & 4 deletions
@@ -16,7 +16,7 @@
 from pydantic import BaseModel, ConfigDict, Field
 from typing import List, Dict, Optional, Any, Union
 from sagemaker.core.utils.utils import Unassigned
-from sagemaker.core.helper.pipeline_variable import StrPipeVar
+from sagemaker.core.helper.pipeline_variable import StrPipeVar, IntPipeVar, BoolPipeVar
 
 # Suppress Pydantic warnings about field names shadowing parent attributes
 warnings.filterwarnings("ignore", message=".*shadows an attribute.*")
@@ -1324,10 +1324,10 @@ class ResourceConfig(Base):
     """
 
     instance_type: Optional[StrPipeVar] = Unassigned()
-    instance_count: Optional[int] = Unassigned()
-    volume_size_in_gb: Optional[int] = Unassigned()
+    instance_count: Optional[IntPipeVar] = Unassigned()
+    volume_size_in_gb: Optional[IntPipeVar] = Unassigned()
     volume_kms_key_id: Optional[StrPipeVar] = Unassigned()
-    keep_alive_period_in_seconds: Optional[int] = Unassigned()
+    keep_alive_period_in_seconds: Optional[IntPipeVar] = Unassigned()
     capacity_reservation_ids: Optional[List[StrPipeVar]] = Unassigned()
     instance_groups: Optional[List[InstanceGroup]] = Unassigned()
    capacity_schedules_config: Optional[CapacitySchedulesConfig] = Unassigned()
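
With instance_count, volume_size_in_gb, and keep_alive_period_in_seconds widened to IntPipeVar, a ResourceConfig can now be built from workflow parameters as well as literals. A hedged sketch (the import path for ResourceConfig and the parameter names are assumptions based on the other files in this commit):

from sagemaker.core.shapes import ResourceConfig
from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString

# Workflow parameters are resolved at pipeline execution time, not at definition time.
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
instance_count = ParameterInteger(name="InstanceCount", default_value=1)

# Before this change instance_count had to be a literal int; IntPipeVar admits the parameter too.
resource_config = ResourceConfig(
    instance_type=instance_type,
    instance_count=instance_count,
    volume_size_in_gb=30,
)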

sagemaker-core/src/sagemaker/core/training/configs.py

Lines changed: 5 additions & 5 deletions
@@ -25,7 +25,7 @@
 from pydantic import BaseModel, model_validator, ConfigDict
 
 import sagemaker.core.shapes as shapes
-from sagemaker.core.helper.pipeline_variable import StrPipeVar
+from sagemaker.core.helper.pipeline_variable import StrPipeVar, IntPipeVar, BoolPipeVar
 
 # TODO: Can we add custom logic to some of these to set better defaults?
 from sagemaker.core.shapes import (
@@ -158,23 +158,23 @@ class Compute(shapes.ResourceConfig):
         instance_type (Optional[StrPipeVar]):
             The ML compute instance type. For information about available instance types,
             see https://aws.amazon.com/sagemaker/pricing/.
-        instance_count (Optional[int]): The number of ML compute instances to use. For distributed
+        instance_count (Optional[IntPipeVar]): The number of ML compute instances to use. For distributed
             training, provide a value greater than 1.
-        volume_size_in_gb (Optional[int]):
+        volume_size_in_gb (Optional[IntPipeVar]):
             The size of the ML storage volume that you want to provision. ML storage volumes store
             model artifacts and incremental states. Training algorithms might also use the ML
             storage volume for scratch space. Default: 30
         volume_kms_key_id (Optional[StrPipeVar]):
             The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage
             volume attached to the ML compute instance(s) that run the training job.
-        keep_alive_period_in_seconds (Optional[int]):
+        keep_alive_period_in_seconds (Optional[IntPipeVar]):
             The duration of time in seconds to retain configured resources in a warm pool for
             subsequent training jobs.
         instance_groups (Optional[List[InstanceGroup]]):
             A list of instance groups for heterogeneous clusters to be used in the training job.
         training_plan_arn (Optional[StrPipeVar]):
             The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
-        enable_managed_spot_training (Optional[bool]):
+        enable_managed_spot_training (Optional[BoolPipeVar]):
             To train models using managed spot training, choose True. Managed spot training
             provides a fully managed and scalable infrastructure for training machine learning
             models. this option is useful when training jobs can be interrupted and when there
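
Compute inherits these fields from shapes.ResourceConfig, so the same widening carries over when configuring a ModelTrainer inside a pipeline. A rough sketch of the intent (parameter names are assumptions; the boolean field is shown with a literal because only ParameterInteger and ParameterString appear in this commit):

from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.train.configs import Compute

training_instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
training_instance_count = ParameterInteger(name="TrainingInstanceCount", default_value=1)

# instance_count (Optional[IntPipeVar]) now accepts the ParameterInteger directly,
# while enable_managed_spot_training (Optional[BoolPipeVar]) still takes a plain bool.
compute = Compute(
    instance_type=training_instance_type,
    instance_count=training_instance_count,
    volume_size_in_gb=50,
    enable_managed_spot_training=False,
)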

sagemaker-mlops/tests/integ/test_pipeline_train_registry.py

Lines changed: 70 additions & 49 deletions
@@ -6,7 +6,12 @@
 from sagemaker.train import ModelTrainer
 from sagemaker.train.configs import InputData, Compute
 from sagemaker.core.processing import ScriptProcessor
-from sagemaker.core.shapes import ProcessingInput, ProcessingS3Input, ProcessingOutput, ProcessingS3Output
+from sagemaker.core.shapes import (
+    ProcessingInput,
+    ProcessingS3Input,
+    ProcessingOutput,
+    ProcessingS3Output,
+)
 from sagemaker.serve.model_builder import ModelBuilder
 from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
 from sagemaker.mlops.workflow.pipeline import Pipeline
@@ -37,22 +42,27 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
     bucket = sagemaker_session.default_bucket()
     prefix = "integ-test-v3-pipeline"
     base_job_prefix = "train-registry-job"
-
+
     # Upload abalone data to S3
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")
     abalone_path = os.path.join(os.path.dirname(__file__), "data", "pipeline", "abalone.csv")
     s3_client.upload_file(abalone_path, bucket, f"{prefix}/input/abalone.csv")
     input_data_s3 = f"s3://{bucket}/{prefix}/input/abalone.csv"
-
+
     # Parameters
     processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
+    training_instance_count = ParameterInteger(name="TrainingInstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
     input_data = ParameterString(
         name="InputDataUrl",
         default_value=input_data_s3,
     )
-
+    hyper_parameter_objective = ParameterString(
+        name="TrainingObjective", default_value="reg:linear"
+    )
+
     cache_config = CacheConfig(enable_caching=True, expire_after="30d")
-
+
     # Processing step
     sklearn_processor = ScriptProcessor(
         image_uri=image_uris.retrieve(
@@ -62,13 +72,13 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
             py_version="py3",
             instance_type="ml.m5.xlarge",
         ),
-        instance_type="ml.m5.xlarge",
+        instance_type=instance_type,
         instance_count=processing_instance_count,
         base_job_name=f"{base_job_prefix}-sklearn",
         sagemaker_session=pipeline_session,
         role=role,
     )
-
+
     processor_args = sklearn_processor.run(
         inputs=[
             ProcessingInput(
@@ -79,7 +89,7 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
                     s3_data_type="S3Prefix",
                     s3_input_mode="File",
                     s3_data_distribution_type="ShardedByS3Key",
-                )
+                ),
             )
         ],
         outputs=[
@@ -88,36 +98,36 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
                 s3_output=ProcessingS3Output(
                     s3_uri=f"s3://{sagemaker_session.default_bucket()}/{prefix}/train",
                     local_path="/opt/ml/processing/train",
-                    s3_upload_mode="EndOfJob"
-                )
+                    s3_upload_mode="EndOfJob",
+                ),
             ),
             ProcessingOutput(
                 output_name="validation",
                 s3_output=ProcessingS3Output(
                     s3_uri=f"s3://{sagemaker_session.default_bucket()}/{prefix}/validation",
                     local_path="/opt/ml/processing/validation",
-                    s3_upload_mode="EndOfJob"
-                )
+                    s3_upload_mode="EndOfJob",
+                ),
             ),
             ProcessingOutput(
                 output_name="test",
                 s3_output=ProcessingS3Output(
                     s3_uri=f"s3://{sagemaker_session.default_bucket()}/{prefix}/test",
                     local_path="/opt/ml/processing/test",
-                    s3_upload_mode="EndOfJob"
-                )
+                    s3_upload_mode="EndOfJob",
+                ),
             ),
         ],
         code=os.path.join(os.path.dirname(__file__), "code", "pipeline", "preprocess.py"),
         arguments=["--input-data", input_data],
     )
-
+
     step_process = ProcessingStep(
         name="PreprocessData",
         step_args=processor_args,
         cache_config=cache_config,
     )
-
+
     # Training step
     image_uri = image_uris.retrieve(
         framework="xgboost",
@@ -126,47 +136,46 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
         py_version="py3",
         instance_type="ml.m5.xlarge",
     )
-
+
     model_trainer = ModelTrainer(
         training_image=image_uri,
-        compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
+        compute=Compute(instance_type=instance_type, instance_count=training_instance_count),
         base_job_name=f"{base_job_prefix}-xgboost",
         sagemaker_session=pipeline_session,
         role=role,
         hyperparameters={
-            "objective": "reg:linear",
+            "objective": hyper_parameter_objective,
             "num_round": 50,
             "max_depth": 5,
         },
         input_data_config=[
             InputData(
                 channel_name="train",
-                data_source=step_process.properties.ProcessingOutputConfig.Outputs["train"].S3Output.S3Uri,
-                content_type="text/csv"
+                data_source=step_process.properties.ProcessingOutputConfig.Outputs[
+                    "train"
+                ].S3Output.S3Uri,
+                content_type="text/csv",
             ),
         ],
     )
-
+
     train_args = model_trainer.train()
     step_train = TrainingStep(
         name="TrainModel",
         step_args=train_args,
         cache_config=cache_config,
     )
-
+
     # Model step
     model_builder = ModelBuilder(
         s3_model_data_url=step_train.properties.ModelArtifacts.S3ModelArtifacts,
         image_uri=image_uri,
         sagemaker_session=pipeline_session,
         role_arn=role,
     )
-
-    step_create_model = ModelStep(
-        name="CreateModel",
-        step_args=model_builder.build()
-    )
-
+
+    step_create_model = ModelStep(name="CreateModel", step_args=model_builder.build())
+
     # Register step
     model_package_group_name = f"integ-test-model-group-{uuid.uuid4().hex[:8]}"
     step_register_model = ModelStep(
@@ -176,33 +185,39 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
             content_types=["application/json"],
             response_types=["application/json"],
             inference_instances=["ml.m5.xlarge"],
-            approval_status="Approved"
-        )
+            approval_status="Approved",
+        ),
     )
-
+
     # Pipeline
     pipeline_name = f"integ-test-train-registry-{uuid.uuid4().hex[:8]}"
     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[processing_instance_count, input_data],
+        parameters=[
+            processing_instance_count,
+            training_instance_count,
+            instance_type,
+            input_data,
+            hyper_parameter_objective,
+        ],
         steps=[step_process, step_train, step_create_model, step_register_model],
         sagemaker_session=pipeline_session,
     )
-
+
     model_name = None
     try:
         # Upsert and execute pipeline
         pipeline.upsert(role_arn=role)
         execution = pipeline.start()
-
+
         # Poll execution status with 30 minute timeout
         timeout = 1800
         start_time = time.time()
-
+
         while time.time() - start_time < timeout:
             execution_desc = execution.describe()
             execution_status = execution_desc["PipelineExecutionStatus"]
-
+
             if execution_status == "Succeeded":
                 # Get model name from execution steps
                 steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(
@@ -219,41 +234,47 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, r
                 steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(
                     PipelineExecutionArn=execution_desc["PipelineExecutionArn"]
                 )["PipelineExecutionSteps"]
-
+
                 failed_steps = []
                 for step in steps:
                     if step.get("StepStatus") == "Failed":
                         failure_reason = step.get("FailureReason", "Unknown reason")
                         failed_steps.append(f"{step['StepName']}: {failure_reason}")
-
-                failure_details = "\n".join(failed_steps) if failed_steps else "No detailed failure information available"
-                pytest.fail(f"Pipeline execution {execution_status}. Failed steps:\n{failure_details}")
-
+
+                failure_details = (
+                    "\n".join(failed_steps)
+                    if failed_steps
+                    else "No detailed failure information available"
+                )
+                pytest.fail(
+                    f"Pipeline execution {execution_status}. Failed steps:\n{failure_details}"
+                )
+
             time.sleep(60)
         else:
             pytest.fail(f"Pipeline execution timed out after {timeout} seconds")
-
+
     finally:
         # Cleanup S3 resources
-        s3 = boto3.resource('s3')
+        s3 = boto3.resource("s3")
         bucket_obj = s3.Bucket(bucket)
-        bucket_obj.objects.filter(Prefix=f'{prefix}/').delete()
-
+        bucket_obj.objects.filter(Prefix=f"{prefix}/").delete()
+
         # Cleanup model
         if model_name:
             try:
                 sagemaker_session.sagemaker_client.delete_model(ModelName=model_name)
             except Exception:
                 pass
-
+
         # Cleanup model package group
         try:
             sagemaker_session.sagemaker_client.delete_model_package_group(
                 ModelPackageGroupName=model_package_group_name
             )
         except Exception:
             pass
-
+
         # Cleanup pipeline
         try:
             sagemaker_session.sagemaker_client.delete_pipeline(PipelineName=pipeline_name)
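
The added parameters make the test's pipeline definition reusable across runs. Assuming this Pipeline supports the same parameter overrides at start time as the classic SageMaker SDK (that call is not shown in this diff), an execution could override the new defaults like this:

# Hypothetical override of the newly added parameters at execution time;
# the start(parameters=...) signature is assumed, not confirmed by this commit.
execution = pipeline.start(
    parameters={
        "TrainingInstanceCount": 2,
        "InstanceType": "ml.c5.2xlarge",
        "TrainingObjective": "reg:squarederror",
    }
)
print(execution.describe()["PipelineExecutionStatus"])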

sagemaker-train/src/sagemaker/train/utils.py

Lines changed: 6 additions & 2 deletions
@@ -26,6 +26,7 @@
 from sagemaker.core.helper.session_helper import Session
 from sagemaker.core.shapes import Unassigned
 from sagemaker.train import logger
+from sagemaker.core.workflow.parameters import PipelineVariable
 
 
 def _default_bucket_and_prefix(session: Session) -> str:
@@ -172,9 +173,10 @@ def safe_serialize(data):
 
     This function handles the following cases:
     1. If `data` is a string, it returns the string as-is without wrapping in quotes.
-    2. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
+    2. If `data` is of type `PipelineVariable`, it returns the json representation of the PipelineVariable
+    3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
        the JSON-encoded string using `json.dumps()`.
-    3. If `data` cannot be serialized (e.g., a custom object), it returns the string
+    4. If `data` cannot be serialized (e.g., a custom object), it returns the string
        representation of the data using `str(data)`.
 
     Args:
@@ -185,6 +187,8 @@ def safe_serialize(data):
     """
     if isinstance(data, str):
         return data
+    elif isinstance(data, PipelineVariable):
+        return data
     try:
         return json.dumps(data)
     except TypeError:
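
The new branch lets safe_serialize pass workflow parameters through untouched instead of stringifying them eagerly, leaving resolution to the workflow layer at execution time. A hedged usage sketch (assuming ParameterString derives from the PipelineVariable imported above, which this commit does not show explicitly):

from sagemaker.train.utils import safe_serialize
from sagemaker.core.workflow.parameters import ParameterString

objective = ParameterString(name="TrainingObjective", default_value="reg:linear")

safe_serialize("reg:linear")   # string returned as-is: 'reg:linear'
safe_serialize(50)             # JSON-encoded: '50'
safe_serialize(objective)      # PipelineVariable instance returned unchanged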
