Skip to content

Commit 5cb02ab

Browse files
committed
fix tests
Signed-off-by: typhoonzero <[email protected]>
1 parent c32a1be commit 5cb02ab

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

elyra/pipeline/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def gpu_vendor(self) -> Optional[str]:
336336

337337
@property
338338
def parallel_count(self) -> Optional[str]:
339-
return self._component_props.get("parallel_count")
339+
return self._component_props.get("parallel_count", 1)
340340

341341
def __eq__(self, other: GenericOperation) -> bool:
342342
if isinstance(self, other.__class__):

elyra/templates/kubeflow/v1/python_dsl_template.jinja2

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ def generated_pipeline(
3434
{% set task_name = "task_" + workflow_task.escaped_task_id %}
3535
# Task for node '{{ workflow_task.name }}'
3636
{% set parallel_indent = 0 %}
37-
{% if workflow_task.task_modifiers.parallel_count > 1 %}
37+
{% if 'parallel_count' in workflow_task.task_modifiers and workflow_task.task_modifiers.parallel_count is not none %}
38+
{% if workflow_task.task_modifiers.parallel_count > 1 %}
3839
{% set parallel_indent = 4 %}
3940
parallel_count = {{workflow_task.task_modifiers.parallel_count}}
4041
with kfp.dsl.ParallelFor(list(range(parallel_count))) as rank:
42+
{% endif %}
4143
{% endif %}
4244

4345
{% filter indent(width=parallel_indent) %}
@@ -81,9 +83,11 @@ def generated_pipeline(
8183
{% for env_var_name, env_var_value in workflow_task.task_modifiers.env_variables.items() %}
8284
{{ task_name }}.add_env_variable(V1EnvVar(name="{{ env_var_name }}", value="{{ env_var_value | string_delimiter_safe }}"))
8385
{% endfor %}
84-
{% if workflow_task.task_modifiers.parallel_count > 1 %}
86+
{% if 'parallel_count' in workflow_task.task_modifiers and workflow_task.task_modifiers.parallel_count is not none %}
87+
{% if workflow_task.task_modifiers.parallel_count > 1 %}
8588
{{ task_name }}.add_env_variable(V1EnvVar(name="NRANKS", value=str(parallel_count)))
8689
{{ task_name }}.add_env_variable(V1EnvVar(name="RANK", value=str(rank)))
90+
{% endif %}
8791
{% endif %}
8892
{% if workflow_engine == "argo" %}
8993
{{ task_name }}.add_env_variable(V1EnvVar(

elyra/tests/pipeline/kfp/test_processor_kfp.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,7 @@ def test_generate_pipeline_dsl_compile_pipeline_dsl_one_generic_node_pipeline_te
735735

736736
# Verify component definition information (see generic_component_definition_template.jinja2)
737737
# - property 'name'
738-
assert node_template["name"] == "run-a-file"
738+
assert node_template["name"] == sanitize_label_value(op.name)
739739
# - property 'implementation.container.command'
740740
assert node_template["container"]["command"] == ["sh", "-c"]
741741
# - property 'implementation.container.args'
@@ -1416,11 +1416,9 @@ def test_generate_pipeline_dsl_compile_pipeline_dsl_generic_components_data_exch
14161416
assert len(compiled_spec["spec"]["templates"]) >= 3
14171417
template_specs = {}
14181418
for node_template in compiled_spec["spec"]["templates"]:
1419-
if node_template["name"] == compiled_spec["spec"]["entrypoint"] or not node_template["name"].startswith(
1420-
"run-a-file"
1421-
):
1419+
if node_template["name"] == compiled_spec["spec"]["entrypoint"]:
14221420
continue
1423-
template_specs[node_template["name"]] = node_template
1421+
template_specs[sanitize_label_value(node_template["name"])] = node_template
14241422

14251423
# Iterate through sorted operations and verify that their inputs
14261424
# and outputs are properly represented in their respective template
@@ -1430,10 +1428,8 @@ def test_generate_pipeline_dsl_compile_pipeline_dsl_generic_components_data_exch
14301428
if not op.is_generic:
14311429
# ignore custom nodes
14321430
continue
1433-
if template_index == 1:
1434-
template_name = "run-a-file"
1435-
else:
1436-
template_name = f"run-a-file-{template_index}"
1431+
template_name = sanitize_label_value(op.name)
1432+
template_name = template_name.replace("_", "-") # kubernetes does this replace
14371433
template_index = template_index + 1
14381434
# compare outputs
14391435
if len(op.outputs) > 0:

0 commit comments

Comments
 (0)