 from sagemaker.train import ModelTrainer
 from sagemaker.train.configs import InputData, Compute
 from sagemaker.core.processing import ScriptProcessor
-from sagemaker.core.shapes import ProcessingInput, ProcessingS3Input, ProcessingOutput, ProcessingS3Output
+from sagemaker.core.shapes import (
+    ProcessingInput,
+    ProcessingS3Input,
+    ProcessingOutput,
+    ProcessingS3Output,
+)
 from sagemaker.serve.model_builder import ModelBuilder
 from sagemaker.core.workflow.parameters import ParameterInteger, ParameterString
 from sagemaker.mlops.workflow.pipeline import Pipeline
@@ -37,22 +42,27 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
     bucket = sagemaker_session.default_bucket()
     prefix = "integ-test-v3-pipeline"
     base_job_prefix = "train-registry-job"
-
+
     # Upload abalone data to S3
-    s3_client = boto3.client('s3')
+    s3_client = boto3.client("s3")
     abalone_path = os.path.join(os.path.dirname(__file__), "data", "pipeline", "abalone.csv")
     s3_client.upload_file(abalone_path, bucket, f"{prefix}/input/abalone.csv")
     input_data_s3 = f"s3://{bucket}/{prefix}/input/abalone.csv"
-
+
     # Parameters
     processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
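+    # Pipeline parameters so instance sizing and the XGBoost objective can be
+    # overridden when starting an execution.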
+    training_instance_count = ParameterInteger(name="TrainingInstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
     input_data = ParameterString(
         name="InputDataUrl",
         default_value=input_data_s3,
     )
-
+    hyper_parameter_objective = ParameterString(
+        name="TrainingObjective", default_value="reg:linear"
+    )
+
     cache_config = CacheConfig(enable_caching=True, expire_after="30d")
-
+
     # Processing step
     sklearn_processor = ScriptProcessor(
         image_uri=image_uris.retrieve(
@@ -62,13 +72,13 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
             py_version="py3",
             instance_type="ml.m5.xlarge",
         ),
-        instance_type="ml.m5.xlarge",
+        instance_type=instance_type,
         instance_count=processing_instance_count,
         base_job_name=f"{base_job_prefix}-sklearn",
         sagemaker_session=pipeline_session,
         role=role,
     )
-
+
     processor_args = sklearn_processor.run(
         inputs=[
             ProcessingInput(
@@ -79,7 +89,7 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
                     s3_data_type="S3Prefix",
                     s3_input_mode="File",
                     s3_data_distribution_type="ShardedByS3Key",
-                )
+                ),
             )
         ],
         outputs=[
@@ -88,36 +98,36 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
                 s3_output=ProcessingS3Output(
                     s3_uri=f"s3://{sagemaker_session.default_bucket()}/{prefix}/train",
                     local_path="/opt/ml/processing/train",
-                    s3_upload_mode="EndOfJob"
-                )
+                    s3_upload_mode="EndOfJob",
+                ),
             ),
             ProcessingOutput(
                 output_name="validation",
                 s3_output=ProcessingS3Output(
                     s3_uri=f"s3://{sagemaker_session.default_bucket()}/{prefix}/validation",
                     local_path="/opt/ml/processing/validation",
-                    s3_upload_mode="EndOfJob"
-                )
+                    s3_upload_mode="EndOfJob",
+                ),
             ),
             ProcessingOutput(
                 output_name="test",
                 s3_output=ProcessingS3Output(
                     s3_uri=f"s3://{sagemaker_session.default_bucket()}/{prefix}/test",
                     local_path="/opt/ml/processing/test",
-                    s3_upload_mode="EndOfJob"
-                )
+                    s3_upload_mode="EndOfJob",
+                ),
             ),
         ],
         code=os.path.join(os.path.dirname(__file__), "code", "pipeline", "preprocess.py"),
         arguments=["--input-data", input_data],
     )
-
+
     step_process = ProcessingStep(
         name="PreprocessData",
         step_args=processor_args,
         cache_config=cache_config,
     )
-
+
     # Training step
     image_uri = image_uris.retrieve(
         framework="xgboost",
@@ -126,47 +136,46 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
         py_version="py3",
         instance_type="ml.m5.xlarge",
     )
-
+
     model_trainer = ModelTrainer(
         training_image=image_uri,
-        compute=Compute(instance_type="ml.m5.xlarge", instance_count=1),
+        compute=Compute(instance_type=instance_type, instance_count=training_instance_count),
        base_job_name=f"{base_job_prefix}-xgboost",
        sagemaker_session=pipeline_session,
        role=role,
        hyperparameters={
-            "objective": "reg:linear",
+            "objective": hyper_parameter_objective,
            "num_round": 50,
            "max_depth": 5,
        },
        input_data_config=[
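+            # The property reference below resolves at run time to the processing
+            # step's "train" output URI, which also makes TrainModel depend on PreprocessData.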
             InputData(
                 channel_name="train",
-                data_source=step_process.properties.ProcessingOutputConfig.Outputs["train"].S3Output.S3Uri,
-                content_type="text/csv"
+                data_source=step_process.properties.ProcessingOutputConfig.Outputs[
+                    "train"
+                ].S3Output.S3Uri,
+                content_type="text/csv",
             ),
         ],
     )
-
+
     train_args = model_trainer.train()
     step_train = TrainingStep(
         name="TrainModel",
         step_args=train_args,
         cache_config=cache_config,
     )
-
+
     # Model step
     model_builder = ModelBuilder(
         s3_model_data_url=step_train.properties.ModelArtifacts.S3ModelArtifacts,
         image_uri=image_uri,
         sagemaker_session=pipeline_session,
         role_arn=role,
     )
-
-    step_create_model = ModelStep(
-        name="CreateModel",
-        step_args=model_builder.build()
-    )
-
+
+    step_create_model = ModelStep(name="CreateModel", step_args=model_builder.build())
+
     # Register step
     model_package_group_name = f"integ-test-model-group-{uuid.uuid4().hex[:8]}"
     step_register_model = ModelStep(
@@ -176,33 +185,39 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
             content_types=["application/json"],
             response_types=["application/json"],
             inference_instances=["ml.m5.xlarge"],
-            approval_status="Approved"
-        )
+            approval_status="Approved",
+        ),
     )
-
+
     # Pipeline
     pipeline_name = f"integ-test-train-registry-{uuid.uuid4().hex[:8]}"
     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[processing_instance_count, input_data],
+        parameters=[
+            processing_instance_count,
+            training_instance_count,
+            instance_type,
+            input_data,
+            hyper_parameter_objective,
+        ],
         steps=[step_process, step_train, step_create_model, step_register_model],
         sagemaker_session=pipeline_session,
     )
-
+
     model_name = None
     try:
         # Upsert and execute pipeline
         pipeline.upsert(role_arn=role)
         execution = pipeline.start()
-
+
         # Poll execution status with 30 minute timeout
         timeout = 1800
         start_time = time.time()
-
+
         while time.time() - start_time < timeout:
             execution_desc = execution.describe()
             execution_status = execution_desc["PipelineExecutionStatus"]
-
+
             if execution_status == "Succeeded":
                 # Get model name from execution steps
                 steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(
@@ -219,41 +234,47 @@ def test_pipeline_with_train_and_registry(sagemaker_session, pipeline_session, role):
                 steps = sagemaker_session.sagemaker_client.list_pipeline_execution_steps(
                     PipelineExecutionArn=execution_desc["PipelineExecutionArn"]
                 )["PipelineExecutionSteps"]
-
+
                 failed_steps = []
                 for step in steps:
                     if step.get("StepStatus") == "Failed":
                         failure_reason = step.get("FailureReason", "Unknown reason")
                         failed_steps.append(f"{step['StepName']}: {failure_reason}")
-
-                failure_details = "\n".join(failed_steps) if failed_steps else "No detailed failure information available"
-                pytest.fail(f"Pipeline execution {execution_status}. Failed steps:\n{failure_details}")
-
+
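+                # Aggregate the per-step FailureReason values so the pytest failure
+                # message pinpoints which step broke.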
+                failure_details = (
+                    "\n".join(failed_steps)
+                    if failed_steps
+                    else "No detailed failure information available"
+                )
+                pytest.fail(
+                    f"Pipeline execution {execution_status}. Failed steps:\n{failure_details}"
+                )
+
             time.sleep(60)
         else:
             pytest.fail(f"Pipeline execution timed out after {timeout} seconds")
-
+
     finally:
         # Cleanup S3 resources
-        s3 = boto3.resource('s3')
+        s3 = boto3.resource("s3")
         bucket_obj = s3.Bucket(bucket)
-        bucket_obj.objects.filter(Prefix=f'{prefix}/').delete()
-
+        bucket_obj.objects.filter(Prefix=f"{prefix}/").delete()
+
         # Cleanup model
         if model_name:
             try:
                 sagemaker_session.sagemaker_client.delete_model(ModelName=model_name)
             except Exception:
                 pass
-
+
         # Cleanup model package group
         try:
             sagemaker_session.sagemaker_client.delete_model_package_group(
                 ModelPackageGroupName=model_package_group_name
             )
         except Exception:
             pass
-
+
         # Cleanup pipeline
         try:
             sagemaker_session.sagemaker_client.delete_pipeline(PipelineName=pipeline_name)