Skip to content

Commit b048202

Browse files
committed
support pipelinechannel on set env vars
Signed-off-by: JerT33 <trestjeremiah@gmail.com> compiler updates Signed-off-by: JerT33 <trestjeremiah@gmail.com>
1 parent 8db042c commit b048202

File tree

3 files changed

+35
-4
lines changed

3 files changed

+35
-4
lines changed

backend/src/v2/driver/driver.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,11 @@ func initPodSpecPatch(
267267
// Convert environment variables
268268
userEnvVar := make([]k8score.EnvVar, 0)
269269
for _, envVar := range container.GetEnv() {
270-
userEnvVar = append(userEnvVar, k8score.EnvVar{Name: envVar.GetName(), Value: envVar.GetValue()})
270+
resolvedValue, err := resolvePodSpecInputRuntimeParameter(envVar.GetValue(), executorInput)
271+
if err != nil {
272+
return nil, fmt.Errorf("failed to resolve environment variable %q: %w", envVar.GetName(), err)
273+
}
274+
userEnvVar = append(userEnvVar, k8score.EnvVar{Name: envVar.GetName(), Value: resolvedValue})
271275
}
272276

273277
userEnvVar = append(userEnvVar, proxy.GetConfig().GetEnvVars()...)

sdk/python/kfp/compiler/pipeline_spec_builder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ def build_task_spec_for_task(
145145
if val and pipeline_channel.extract_pipeline_channels_from_any(val):
146146
task.inputs['base_image'] = val
147147

148+
if task.container_spec and task.container_spec.env:
149+
for env_name, env_val in task.container_spec.env.items():
150+
if env_val and pipeline_channel.extract_pipeline_channels_from_any(
151+
env_val):
152+
task.inputs[env_name] = env_val
153+
148154
for input_name, input_value in task.inputs.items():
149155
# Since LoopParameterArgument and LoopArtifactArgument and LoopArgumentVariable are narrower
150156
# types than PipelineParameterChannel, start with them.
@@ -748,7 +754,7 @@ def convert_to_placeholder(input_value: str) -> str:
748754
args=task.container_spec.args,
749755
env=[
750756
pipeline_spec_pb2.PipelineDeploymentConfig.PipelineContainerSpec
751-
.EnvVar(name=name, value=value)
757+
.EnvVar(name=name, value=convert_to_placeholder(value))
752758
for name, value in (task.container_spec.env or {}).items()
753759
]))
754760

sdk/python/kfp/dsl/pipeline_task.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,22 +623,43 @@ def set_display_name(self, name: str) -> 'PipelineTask':
623623
return self
624624

625625
@block_if_final()
626-
def set_env_variable(self, name: str, value: str) -> 'PipelineTask':
626+
def set_env_variable(
627+
self,
628+
name: str,
629+
value: Union[str, pipeline_channel.PipelineChannel],
630+
) -> 'PipelineTask':
627631
"""Sets environment variable for the task.
628632
629633
Args:
630634
name: Environment variable name.
631-
value: Environment variable value.
635+
value: Environment variable value. Supports dynamic values such as
636+
Pipeline Parameters or outputs from previous tasks, which are
637+
resolved at runtime.
632638
633639
Returns:
634640
Self return to allow chained setting calls.
635641
"""
636642
self._ensure_container_spec_exists()
637643

644+
pipeline_channels = (
645+
pipeline_channel.extract_pipeline_channels_from_any(value))
646+
647+
if isinstance(value, pipeline_channel.PipelineChannel):
648+
value = str(value)
649+
638650
if self.container_spec.env is not None:
639651
self.container_spec.env[name] = value
640652
else:
641653
self.container_spec.env = {name: value}
654+
655+
if pipeline_channels:
656+
existing_channel_patterns = {
657+
channel.pattern for channel in self._channel_inputs
658+
}
659+
for channel in pipeline_channels:
660+
if channel.pattern not in existing_channel_patterns:
661+
self._channel_inputs.append(channel)
662+
existing_channel_patterns.add(channel.pattern)
642663
return self
643664

644665
@block_if_final()

0 commit comments

Comments
 (0)