@@ -240,6 +240,39 @@ def image_builder(buildspec, image_types=[], device_types=[]):
240240 }
241241 }
242242 )
243+ # job_type will be either inference or training, based on the repo URI
244+ if "training" in image_repo_uri :
245+ label_job_type = "training"
246+ elif "inference" in image_repo_uri :
247+ label_job_type = "inference"
248+ else :
249+ raise RuntimeError (
250+ f"Cannot find inference or training job type in { image_repo_uri } . "
251+ f"This is required to set job_type label."
252+ )
253+
254+ template_file = os .path .join (
255+ os .sep , get_cloned_folder_path (), "miscellaneous_scripts" , "dlc_template.py"
256+ )
257+
258+ template_fw_version = (
259+ str (image_config ["framework_version" ])
260+ if image_config .get ("framework_version" )
261+ else str (BUILDSPEC ["version" ])
262+ )
263+ template_fw = str (BUILDSPEC ["framework" ])
264+ post_template_file = utils .generate_dlc_cmd (
265+ template_path = template_file ,
266+ output_path = os .path .join (image_config ["root" ], "out.py" ),
267+ framework = template_fw ,
268+ framework_version = template_fw_version ,
269+ container_type = label_job_type ,
270+ )
271+
272+ ARTIFACTS .update (
273+ {"customize" : {"source" : post_template_file , "target" : "sitecustomize.py" }}
274+ )
275+
243276 context = Context (ARTIFACTS , f"build/{ image_name } .tar.gz" , image_config ["root" ])
244277
245278 if "labels" in image_config :
@@ -265,17 +298,6 @@ def image_builder(buildspec, image_types=[], device_types=[]):
265298 label_contributor = str (BUILDSPEC .get ("contributor" ))
266299 label_transformers_version = str (transformers_version ).replace ("." , "-" )
267300
268- # job_type will be either inference or training, based on the repo URI
269- if "training" in image_repo_uri :
270- label_job_type = "training"
271- elif "inference" in image_repo_uri :
272- label_job_type = "inference"
273- else :
274- raise RuntimeError (
275- f"Cannot find inference or training job type in { image_repo_uri } . "
276- f"This is required to set job_type label."
277- )
278-
279301 if cx_type == "sagemaker" :
280302 # Adding standard labels to all images
281303 labels [
0 commit comments