Skip to content

Commit 8e436b9

Browse files
committed
support pipelinechannel on set env vars
Signed-off-by: JerT33 <trestjeremiah@gmail.com> compiler updates Signed-off-by: JerT33 <trestjeremiah@gmail.com> address copilot feedback Signed-off-by: JerT33 <trestjeremiah@gmail.com> fix tests Signed-off-by: JerT33 <trestjeremiah@gmail.com> fix formatting Signed-off-by: JerT33 <trestjeremiah@gmail.com>
1 parent 8db042c commit 8e436b9

File tree

5 files changed

+130
-7
lines changed

5 files changed

+130
-7
lines changed

backend/src/v2/driver/driver.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,11 @@ func initPodSpecPatch(
267267
// Convert environment variables
268268
userEnvVar := make([]k8score.EnvVar, 0)
269269
for _, envVar := range container.GetEnv() {
270-
userEnvVar = append(userEnvVar, k8score.EnvVar{Name: envVar.GetName(), Value: envVar.GetValue()})
270+
resolvedValue, err := resolveInputParameterPlaceholders(envVar.GetValue(), executorInput)
271+
if err != nil {
272+
return nil, fmt.Errorf("failed to resolve environment variable %q: %w", envVar.GetName(), err)
273+
}
274+
userEnvVar = append(userEnvVar, k8score.EnvVar{Name: envVar.GetName(), Value: resolvedValue})
271275
}
272276

273277
userEnvVar = append(userEnvVar, proxy.GetConfig().GetEnvVars()...)

sdk/python/kfp/compiler/compiler_test.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,80 @@ def simple_pipeline():
10811081
'empty-component']
10821082
self.assertTrue('inputs' not in dag_task)
10831083

1084+
def test_pipeline_with_pipeline_channel_env_variable(self):
1085+
with tempfile.TemporaryDirectory() as tmpdir:
1086+
1087+
@dsl.component(base_image='python:3.11')
1088+
def empty_component():
1089+
pass
1090+
1091+
@dsl.pipeline()
1092+
def simple_pipeline(my_var: str):
1093+
task = empty_component()
1094+
task.set_env_variable('MY_VAR', my_var)
1095+
1096+
output_yaml = os.path.join(tmpdir, 'result.yaml')
1097+
compiler.Compiler().compile(
1098+
pipeline_func=simple_pipeline,
1099+
package_path=output_yaml,
1100+
pipeline_parameters={'my_var': 'default'})
1101+
self.assertTrue(os.path.exists(output_yaml))
1102+
1103+
with open(output_yaml, 'r') as f:
1104+
pipeline_spec = yaml.safe_load(f)
1105+
container = pipeline_spec['deploymentSpec']['executors'][
1106+
'exec-empty-component']['container']
1107+
env_vars = {
1108+
e['name']: e['value'] for e in container.get('env', [])
1109+
}
1110+
self.assertEqual(
1111+
env_vars['MY_VAR'],
1112+
"{{$.inputs.parameters['pipelinechannel--my_var']}}")
1113+
input_parameters = pipeline_spec['root']['dag']['tasks'][
1114+
'empty-component']['inputs']['parameters']
1115+
self.assertIn('pipelinechannel--my_var', input_parameters)
1116+
1117+
def test_pipeline_with_task_output_env_variable(self):
1118+
with tempfile.TemporaryDirectory() as tmpdir:
1119+
1120+
@dsl.component(base_image='python:3.11')
1121+
def producer() -> str:
1122+
return 'hello'
1123+
1124+
@dsl.component(base_image='python:3.11')
1125+
def consumer():
1126+
pass
1127+
1128+
@dsl.pipeline()
1129+
def task_output_pipeline():
1130+
prod_task = producer()
1131+
cons_task = consumer()
1132+
cons_task.set_env_variable('MY_VAR', prod_task.output)
1133+
1134+
output_yaml = os.path.join(tmpdir, 'result.yaml')
1135+
compiler.Compiler().compile(
1136+
pipeline_func=task_output_pipeline, package_path=output_yaml)
1137+
self.assertTrue(os.path.exists(output_yaml))
1138+
1139+
with open(output_yaml, 'r') as f:
1140+
pipeline_spec = yaml.safe_load(f)
1141+
container = pipeline_spec['deploymentSpec']['executors'][
1142+
'exec-consumer']['container']
1143+
env_vars = {
1144+
e['name']: e['value'] for e in container.get('env', [])
1145+
}
1146+
self.assertEqual(
1147+
env_vars['MY_VAR'],
1148+
"{{$.inputs.parameters['pipelinechannel--producer-Output']}}"
1149+
)
1150+
input_parameters = pipeline_spec['root']['dag']['tasks'][
1151+
'consumer']['inputs']['parameters']
1152+
self.assertIn('pipelinechannel--producer-Output',
1153+
input_parameters)
1154+
self.assertEqual(
1155+
input_parameters['pipelinechannel--producer-Output']
1156+
['taskOutputParameter']['producerTask'], 'producer')
1157+
10841158
def test_compile_with_kubernetes_manifest_format(self):
10851159
with tempfile.TemporaryDirectory() as tmpdir:
10861160

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ def build_task_spec_for_task(
145145
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
146146
task.inputs['base_image'] = val
147147

148+
if task.container_spec and task.container_spec.env:
149+
for env_name, env_val in task.container_spec.env.items():
150+
if env_val and pipeline_channel.extract_pipeline_channels_from_any(
151+
env_val):
152+
if env_name in task.inputs:
153+
raise ValueError(
154+
f'Environment variable name "{env_name}" collides '
155+
f'with an existing task input. Please rename either '
156+
f'the input or the environment variable.')
157+
task.inputs[env_name] = env_val
158+
148159
for input_name, input_value in task.inputs.items():
149160
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
150161
# types than PipelineParameterChannel, start with them.
@@ -729,9 +740,7 @@ def convert_to_placeholder(input_value: str) -> str:
729740
compiler injected input name."""
730741
pipeline_channels = (
731742
pipeline_channel.extract_pipeline_channels_from_any(input_value))
732-
if pipeline_channels:
733-
assert len(pipeline_channels) == 1
734-
channel = pipeline_channels[0]
743+
for channel in pipeline_channels:
735744
additional_input_name = (
736745
compiler_utils.additional_input_name_for_pipeline_channel(
737746
channel))
@@ -748,7 +757,7 @@ def convert_to_placeholder(input_value: str) -> str:
748757
args=task.container_spec.args,
749758
env=[
750759
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec
751-
.EnvVar(name=name, value=value)
760+
.EnvVar(name=name, value=convert_to_placeholder(value))
752761
for name, value in (task.container_spec.env or {}).items()
753762
]))
754763

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,22 +623,43 @@ def set_display_name(self, name: str) -> 'PipelineTask':
623623
return self
624624

625625
@block_if_final()
626-
def set_env_variable(self, name: str, value: str) -> 'PipelineTask':
626+
def set_env_variable(
627+
self,
628+
name: str,
629+
value: Union[str, pipeline_channel.PipelineChannel],
630+
) -> 'PipelineTask':
627631
"""Sets environment variable for the task.
628632
629633
Args:
630634
name: Environment variable name.
631-
value: Environment variable value.
635+
value: Environment variable value. Supports dynamic values such as
636+
Pipeline Parameters or outputs from previous tasks, which are
637+
resolved at runtime.
632638
633639
Returns:
634640
Self return to allow chained setting calls.
635641
"""
636642
self._ensure_container_spec_exists()
637643

644+
pipeline_channels = (
645+
pipeline_channel.extract_pipeline_channels_from_any(value))
646+
647+
if isinstance(value, pipeline_channel.PipelineChannel):
648+
value = str(value)
649+
638650
if self.container_spec.env is not None:
639651
self.container_spec.env[name] = value
640652
else:
641653
self.container_spec.env = {name: value}
654+
655+
if pipeline_channels:
656+
existing_channel_patterns = {
657+
channel.pattern for channel in self._channel_inputs
658+
}
659+
for channel in pipeline_channels:
660+
if channel.pattern not in existing_channel_patterns:
661+
self._channel_inputs.append(channel)
662+
existing_channel_patterns.add(channel.pattern)
642663
return self
643664

644665
@block_if_final()

sdk/python/kfp/dsl/pipeline_task_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from absl.testing import parameterized
2020
from kfp import dsl
21+
from kfp.dsl import pipeline_channel
2122
from kfp.dsl import pipeline_task
2223
from kfp.dsl import placeholders
2324
from kfp.dsl import structures
@@ -366,6 +367,20 @@ def test_set_env_variable(self):
366367
task.set_env_variable('env_name', 'env_value')
367368
self.assertEqual({'env_name': 'env_value'}, task.container_spec.env)
368369

370+
def test_set_env_variable_with_pipeline_channel(self):
371+
task = pipeline_task.PipelineTask(
372+
component_spec=structures.ComponentSpec.from_yaml_documents(
373+
V2_YAML),
374+
args={'input1': 'value'},
375+
)
376+
channel = pipeline_channel.PipelineParameterChannel(
377+
name='param', channel_type='String', task_name='upstream')
378+
task.set_env_variable('MY_VAR', channel)
379+
self.assertIn('MY_VAR', task.container_spec.env)
380+
self.assertEqual(str(channel), task.container_spec.env['MY_VAR'])
381+
self.assertEqual(1, len(task._channel_inputs))
382+
self.assertEqual(channel.pattern, task._channel_inputs[0].pattern)
383+
369384
def test_set_display_name(self):
370385
task = pipeline_task.PipelineTask(
371386
component_spec=structures.ComponentSpec.from_yaml_documents(

0 commit comments

Comments
 (0)