Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions kale/processors/nbprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,13 @@ def parse_cell_metadata(self, metadata):
)
parsed_tags["annotations"] = cell_annotations

if cell_labels:
if missing_step_names:
raise ValueError(
"A cell can not provide Pod labels in a cell that does not declare a step name."
)
parsed_tags["labels"] = cell_labels

if cell_limits:
if missing_step_names:
raise ValueError(
Expand Down
14 changes: 13 additions & 1 deletion kale/templates/pipeline_template.jinja2
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import kfp.dsl as kfp_dsl
from kfp.dsl import Input, Output, Dataset, HTML, Metrics, Artifact, Model
from kfp.kubernetes import security_context
from kfp.kubernetes import add_pod_annotation, add_pod_label, security_context

{{ lightweight_components | join('\n\n') }}

Expand Down Expand Up @@ -65,6 +65,18 @@ def auto_generated_pipeline(
{%- endfor %}
{%- endif %}

{%- if step.config.labels %}
{%- for label_key, label_value in step.config.labels.items() %}
add_pod_label(task={{ step.name }}_task, label_key="{{ label_key }}", label_value="{{ label_value }}")
{%- endfor %}
{%- endif %}

{%- if step.config.annotations %}
{%- for ann_key, ann_value in step.config.annotations.items() %}
add_pod_annotation(task={{ step.name }}_task, annotation_key="{{ ann_key }}", annotation_value="{{ ann_value }}")
{%- endfor %}
{%- endif %}

{% endfor %}

if __name__ == "__main__":
Expand Down
8 changes: 7 additions & 1 deletion kale/tests/assets/kfp_dsl/iris.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import kfp.dsl as kfp_dsl
from kfp.dsl import Input, Output, Dataset, HTML, Metrics, Artifact, Model
from kfp.kubernetes import security_context
from kfp.kubernetes import add_pod_annotation, add_pod_label, security_context


@kfp_dsl.component(
Expand Down Expand Up @@ -304,6 +304,8 @@ def auto_generated_pipeline(

load_transform_data_task.set_display_name("load-transform-data-step")
load_transform_data_task.set_caching_options(enable_caching=False)
add_pod_label(task=load_transform_data_task,
label_key="access-ml-pipeline", label_value="true")

train_model_task = train_model_step(
x_trn_input_artifact=load_transform_data_task.outputs["x_trn_output_artifact"],
Expand All @@ -325,6 +327,8 @@ def auto_generated_pipeline(

train_model_task.set_display_name("train-model-step")
train_model_task.set_caching_options(enable_caching=False)
add_pod_label(task=train_model_task,
label_key="access-ml-pipeline", label_value="true")

evaluate_model_task = evaluate_model_step(
model_input_artifact=train_model_task.outputs["model_output_artifact"],
Expand All @@ -348,6 +352,8 @@ def auto_generated_pipeline(

evaluate_model_task.set_display_name("evaluate-model-step")
evaluate_model_task.set_caching_options(enable_caching=False)
add_pod_label(task=evaluate_model_task,
label_key="access-ml-pipeline", label_value="true")


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import kfp.dsl as kfp_dsl
from kfp.dsl import Input, Output, Dataset, HTML, Metrics, Artifact, Model
from kfp.kubernetes import security_context
from kfp.kubernetes import add_pod_annotation, add_pod_label, security_context


@kfp_dsl.component(
Expand Down
Loading