Skip to content

Commit b8255d2

Browse files
committed
support pipelinechannel on set env vars
Signed-off-by: JerT33 <trestjeremiah@gmail.com> compiler updates Signed-off-by: JerT33 <trestjeremiah@gmail.com> address copilot feedback Signed-off-by: JerT33 <trestjeremiah@gmail.com>
1 parent 8db042c commit b8255d2

File tree

5 files changed

+88
-7
lines changed

5 files changed

+88
-7
lines changed

backend/src/v2/driver/driver.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,11 @@ func initPodSpecPatch(
267267
// Convert environment variables
268268
userEnvVar := make([]k8score.EnvVar, 0)
269269
for _, envVar := range container.GetEnv() {
270-
userEnvVar = append(userEnvVar, k8score.EnvVar{Name: envVar.GetName(), Value: envVar.GetValue()})
270+
resolvedValue, err := resolveInputParameterPlaceholders(envVar.GetValue(), executorInput)
271+
if err != nil {
272+
return nil, fmt.Errorf("failed to resolve environment variable %q: %w", envVar.GetName(), err)
273+
}
274+
userEnvVar = append(userEnvVar, k8score.EnvVar{Name: envVar.GetName(), Value: resolvedValue})
271275
}
272276

273277
userEnvVar = append(userEnvVar, proxy.GetConfig().GetEnvVars()...)

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,39 @@ def simple_pipeline():
10811081
'empty-component']
10821082
self.assertTrue('inputs' not in dag_task)
10831083

1084+
def test_pipeline_with_pipeline_channel_env_variable(self):
1085+
with tempfile.TemporaryDirectory() as tmpdir:
1086+
1087+
@dsl.component(base_image='python:3.11')
1088+
def empty_component():
1089+
pass
1090+
1091+
@dsl.pipeline()
1092+
def simple_pipeline(my_var: str):
1093+
task = empty_component()
1094+
task.set_env_variable('MY_VAR', my_var)
1095+
1096+
output_yaml = os.path.join(tmpdir, 'result.yaml')
1097+
compiler.Compiler().compile(
1098+
pipeline_func=simple_pipeline,
1099+
package_path=output_yaml,
1100+
pipeline_parameters={'my_var': 'default'})
1101+
self.assertTrue(os.path.exists(output_yaml))
1102+
1103+
with open(output_yaml, 'r') as f:
1104+
pipeline_spec = yaml.safe_load(f)
1105+
container = pipeline_spec['deploymentSpec']['executors'][
1106+
'exec-empty-component']['container']
1107+
env_vars = {
1108+
e['name']: e['value'] for e in container.get('env', [])
1109+
}
1110+
self.assertEqual(
1111+
env_vars['MY_VAR'],
1112+
"{{$.inputs.parameters['pipelinechannel--my_var']}}")
1113+
input_parameters = pipeline_spec['root']['dag']['tasks'][
1114+
'empty-component']['inputs']['parameters']
1115+
self.assertIn('pipelinechannel--my_var', input_parameters)
1116+
10841117
def test_compile_with_kubernetes_manifest_format(self):
10851118
with tempfile.TemporaryDirectory() as tmpdir:
10861119

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ def build_task_spec_for_task(
145145
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
146146
task.inputs['base_image'] = val
147147

148+
if task.container_spec and task.container_spec.env:
149+
for env_name, env_val in task.container_spec.env.items():
150+
if env_val and pipeline_channel.extract_pipeline_channels_from_any(
151+
env_val):
152+
if env_name in task.inputs:
153+
raise ValueError(
154+
f'Environment variable name "{env_name}" collides '
155+
f'with an existing task input. Please rename either '
156+
f'the input or the environment variable.')
157+
task.inputs[env_name] = env_val
158+
148159
for input_name, input_value in task.inputs.items():
149160
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
150161
# types than PipelineParameterChannel, start with them.
@@ -729,9 +740,7 @@ def convert_to_placeholder(input_value: str) -> str:
729740
compiler injected input name."""
730741
pipeline_channels = (
731742
pipeline_channel.extract_pipeline_channels_from_any(input_value))
732-
if pipeline_channels:
733-
assert len(pipeline_channels) == 1
734-
channel = pipeline_channels[0]
743+
for channel in pipeline_channels:
735744
additional_input_name = (
736745
compiler_utils.additional_input_name_for_pipeline_channel(
737746
channel))
@@ -748,7 +757,7 @@ def convert_to_placeholder(input_value: str) -> str:
748757
args=task.container_spec.args,
749758
env=[
750759
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec
751-
.EnvVar(name=name, value=value)
760+
.EnvVar(name=name, value=convert_to_placeholder(value))
752761
for name, value in (task.container_spec.env or {}).items()
753762
]))
754763

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,22 +623,43 @@ def set_display_name(self, name: str) -> 'PipelineTask':
623623
return self
624624

625625
@block_if_final()
626-
def set_env_variable(self, name: str, value: str) -> 'PipelineTask':
626+
def set_env_variable(
627+
self,
628+
name: str,
629+
value: Union[str, pipeline_channel.PipelineChannel],
630+
) -> 'PipelineTask':
627631
"""Sets environment variable for the task.
628632
629633
Args:
630634
name: Environment variable name.
631-
value: Environment variable value.
635+
value: Environment variable value. Supports dynamic values such as
636+
Pipeline Parameters or outputs from previous tasks, which are
637+
resolved at runtime.
632638
633639
Returns:
634640
Self return to allow chained setting calls.
635641
"""
636642
self._ensure_container_spec_exists()
637643

644+
pipeline_channels = (
645+
pipeline_channel.extract_pipeline_channels_from_any(value))
646+
647+
if isinstance(value, pipeline_channel.PipelineChannel):
648+
value = str(value)
649+
638650
if self.container_spec.env is not None:
639651
self.container_spec.env[name] = value
640652
else:
641653
self.container_spec.env = {name: value}
654+
655+
if pipeline_channels:
656+
existing_channel_patterns = {
657+
channel.pattern for channel in self._channel_inputs
658+
}
659+
for channel in pipeline_channels:
660+
if channel.pattern not in existing_channel_patterns:
661+
self._channel_inputs.append(channel)
662+
existing_channel_patterns.add(channel.pattern)
642663
return self
643664

644665
@block_if_final()

sdk/python/kfp/dsl/pipeline_task_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,20 @@ def test_set_env_variable(self):
366366
task.set_env_variable('env_name', 'env_value')
367367
self.assertEqual({'env_name': 'env_value'}, task.container_spec.env)
368368

369+
def test_set_env_variable_with_pipeline_channel(self):
370+
task = pipeline_task.PipelineTask(
371+
component_spec=structures.ComponentSpec.from_yaml_documents(
372+
V2_YAML),
373+
args={'input1': 'value'},
374+
)
375+
channel = pipeline_channel.PipelineParameterChannel(
376+
name='param', channel_type='String', task_name='upstream')
377+
task.set_env_variable('MY_VAR', channel)
378+
self.assertIn('MY_VAR', task.container_spec.env)
379+
self.assertEqual(str(channel), task.container_spec.env['MY_VAR'])
380+
self.assertEqual(1, len(task._channel_inputs))
381+
self.assertEqual(channel.pattern, task._channel_inputs[0].pattern)
382+
369383
def test_set_display_name(self):
370384
task = pipeline_task.PipelineTask(
371385
component_spec=structures.ComponentSpec.from_yaml_documents(

0 commit comments

Comments
 (0)