diff --git a/kale/processors/nbprocessor.py b/kale/processors/nbprocessor.py index 908b73fd1..a8245692a 100644 --- a/kale/processors/nbprocessor.py +++ b/kale/processors/nbprocessor.py @@ -507,6 +507,13 @@ def parse_cell_metadata(self, metadata): ) parsed_tags["annotations"] = cell_annotations + if cell_labels: + if missing_step_names: + raise ValueError( + "A cell can not provide Pod labels in a cell that does not declare a step name." + ) + parsed_tags["labels"] = cell_labels + if cell_limits: if missing_step_names: raise ValueError( diff --git a/kale/templates/pipeline_template.jinja2 b/kale/templates/pipeline_template.jinja2 index 2247db6a5..8dd83e96b 100644 --- a/kale/templates/pipeline_template.jinja2 +++ b/kale/templates/pipeline_template.jinja2 @@ -1,7 +1,7 @@ import json import kfp.dsl as kfp_dsl from kfp.dsl import Input, Output, Dataset, HTML, Metrics, Artifact, Model -from kfp.kubernetes import security_context +from kfp.kubernetes import add_pod_annotation, add_pod_label, security_context {{ lightweight_components | join('\n\n') }} @@ -65,6 +65,18 @@ def auto_generated_pipeline( {%- endfor %} {%- endif %} + {%- if step.config.labels %} + {%- for label_key, label_value in step.config.labels.items() %} + add_pod_label(task={{ step.name }}_task, label_key="{{ label_key }}", label_value="{{ label_value }}") + {%- endfor %} + {%- endif %} + + {%- if step.config.annotations %} + {%- for ann_key, ann_value in step.config.annotations.items() %} + add_pod_annotation(task={{ step.name }}_task, annotation_key="{{ ann_key }}", annotation_value="{{ ann_value }}") + {%- endfor %} + {%- endif %} + {% endfor %} if __name__ == "__main__": diff --git a/kale/tests/assets/kfp_dsl/iris.py b/kale/tests/assets/kfp_dsl/iris.py index bbb85e053..89a687c3a 100644 --- a/kale/tests/assets/kfp_dsl/iris.py +++ b/kale/tests/assets/kfp_dsl/iris.py @@ -1,7 +1,7 @@ import json import kfp.dsl as kfp_dsl from kfp.dsl import Input, Output, Dataset, HTML, Metrics, Artifact, Model -from kfp.kubernetes import security_context +from kfp.kubernetes import add_pod_annotation, add_pod_label, security_context @kfp_dsl.component( @@ -304,6 +304,8 @@ def auto_generated_pipeline( load_transform_data_task.set_display_name("load-transform-data-step") load_transform_data_task.set_caching_options(enable_caching=False) + add_pod_label(task=load_transform_data_task, + label_key="access-ml-pipeline", label_value="true") train_model_task = train_model_step( x_trn_input_artifact=load_transform_data_task.outputs["x_trn_output_artifact"], @@ -325,6 +327,8 @@ def auto_generated_pipeline( train_model_task.set_display_name("train-model-step") train_model_task.set_caching_options(enable_caching=False) + add_pod_label(task=train_model_task, + label_key="access-ml-pipeline", label_value="true") evaluate_model_task = evaluate_model_step( model_input_artifact=train_model_task.outputs["model_output_artifact"], @@ -348,6 +352,8 @@ def auto_generated_pipeline( evaluate_model_task.set_display_name("evaluate-model-step") evaluate_model_task.set_caching_options(enable_caching=False) + add_pod_label(task=evaluate_model_task, + label_key="access-ml-pipeline", label_value="true") if __name__ == "__main__": diff --git a/kale/tests/assets/kfp_dsl/pipeline_parameters_and_metrics.py b/kale/tests/assets/kfp_dsl/pipeline_parameters_and_metrics.py index a950fd788..0aba17f0f 100644 --- a/kale/tests/assets/kfp_dsl/pipeline_parameters_and_metrics.py +++ b/kale/tests/assets/kfp_dsl/pipeline_parameters_and_metrics.py @@ -1,7 +1,7 @@ import json import kfp.dsl as kfp_dsl from kfp.dsl import Input, Output, Dataset, HTML, Metrics, Artifact, Model -from kfp.kubernetes import security_context +from kfp.kubernetes import add_pod_annotation, add_pod_label, security_context @kfp_dsl.component(