Skip to content

Commit b83a935

Browse files
committed
fix(runtimes): propagate trainer environment variables to worker processes
Address issue #3427 by:
1. Ensuring environment variables are injected into MPI worker pods regardless of the runLauncherAsNode setting.
2. Automatically populating OMPI_MCA_mca_base_env_list on the launcher so that variables are propagated across the SSH boundary.
3. Filtering out pod-specific environment variables (those defined with ValueFrom) during propagation.
4. Adding comprehensive unit tests for these scenarios.

Closes #3427

Signed-off-by: AviralKaushal <aviralkaush@gmail.com>
1 parent 7dfaa64 commit b83a935

5 files changed

Lines changed: 331 additions & 1 deletion

File tree

pkg/constants/constants.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ const (
176176

177177
// OpenMPIEnvDefaultSlots is the OpenMPI default number of slots env key.
178178
OpenMPIEnvDefaultSlots string = "OMPI_MCA_orte_set_default_slots"
179+
180+
// OpenMPIEnvBaseEnvList is the env key for the OpenMPI list of environment variable names to export from the launcher to all nodes.
181+
OpenMPIEnvBaseEnvList string = "OMPI_MCA_mca_base_env_list"
179182
// Distributed envs for torchrun.
180183
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
181184
// TorchEnvNumNodes is the env name for the number of training nodes.

pkg/runtime/framework/plugins/jobset/builder.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build
139139
}
140140
}
141141
}
142-
if ancestor == constants.AncestorTrainer || b.isRunLauncherAsNode(info) && *rJob.Name == constants.Node {
142+
if ancestor == constants.AncestorTrainer || *rJob.Name == constants.Node {
143143
// TODO (andreyvelich): For MPI we should apply container resources to the Node ReplicatedJob also.
144144
// Eventually, we should find better way to propagate resources from TrainJob to JobSet.
145145
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
Copyright 2025 The Kubeflow Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package jobset
18+
19+
import (
20+
"testing"
21+
22+
corev1 "k8s.io/api/core/v1"
23+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
24+
batchv1ac "k8s.io/client-go/applyconfigurations/batch/v1"
25+
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
26+
"k8s.io/utils/ptr"
27+
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"
28+
29+
trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
30+
"github.com/kubeflow/trainer/v2/pkg/constants"
31+
"github.com/kubeflow/trainer/v2/pkg/runtime"
32+
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
33+
)
34+
35+
func TestBuilderTrainerEnvPropagation(t *testing.T) {
36+
testCases := map[string]struct {
37+
trainJobEnv []corev1.EnvVar
38+
initialPodEnv []corev1ac.EnvVarApplyConfiguration
39+
expectedEnv []corev1.EnvVar
40+
}{
41+
"variables propagated": {
42+
trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}},
43+
expectedEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}},
44+
},
45+
"no variables propagated (empty case)": {
46+
trainJobEnv: []corev1.EnvVar{},
47+
expectedEnv: []corev1.EnvVar{},
48+
},
49+
"merge with existing variables": {
50+
trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}},
51+
initialPodEnv: []corev1ac.EnvVarApplyConfiguration{
52+
*corev1ac.EnvVar().WithName("EXISTING_VAR").WithValue("existing_value"),
53+
},
54+
expectedEnv: []corev1.EnvVar{
55+
{Name: "EXISTING_VAR", Value: "existing_value"},
56+
{Name: "CUSTOM_VAR", Value: "custom_value"},
57+
},
58+
},
59+
}
60+
61+
for name, tc := range testCases {
62+
t.Run(name, func(t *testing.T) {
63+
// Setup TrainJob
64+
trainJob := utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
65+
Trainer(utiltesting.MakeTrainJobTrainerWrapper().
66+
Env(tc.trainJobEnv...).
67+
Obj()).
68+
Obj()
69+
70+
// Setup runtime info for MPI (launcher is NOT a node)
71+
info := &runtime.Info{
72+
RuntimePolicy: runtime.RuntimePolicy{
73+
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
74+
MPIPolicy(nil, trainer.MPIImplementationOpenMPI, nil, ptr.To(false)).
75+
Obj(),
76+
},
77+
}
78+
79+
// Create JobSet spec with initial environment
80+
container := corev1ac.Container().WithName(constants.Node)
81+
for i := range tc.initialPodEnv {
82+
container.WithEnv(&tc.initialPodEnv[i])
83+
}
84+
85+
jobSetSpec := jobsetv1alpha2ac.JobSetSpec().WithReplicatedJobs(
86+
jobsetv1alpha2ac.ReplicatedJob().
87+
WithName(constants.Node).
88+
WithTemplate(batchv1ac.JobTemplateSpec().
89+
WithSpec(batchv1ac.JobSpec().
90+
WithTemplate(corev1ac.PodTemplateSpec().
91+
WithSpec(corev1ac.PodSpec().WithContainers(container)),
92+
),
93+
),
94+
),
95+
)
96+
97+
builder := NewBuilder(jobsetv1alpha2ac.JobSet("test-job", metav1.NamespaceDefault).WithSpec(jobSetSpec))
98+
builder.Trainer(info, trainJob)
99+
100+
// Verify results
101+
var actualEnv []corev1.EnvVar
102+
for _, rJob := range builder.Spec.ReplicatedJobs {
103+
if *rJob.Name == constants.Node {
104+
for _, c := range rJob.Template.Spec.Template.Spec.Containers {
105+
if *c.Name == constants.Node {
106+
for _, env := range c.Env {
107+
actualEnv = append(actualEnv, corev1.EnvVar{Name: *env.Name, Value: *env.Value})
108+
}
109+
}
110+
}
111+
}
112+
}
113+
114+
if len(actualEnv) != len(tc.expectedEnv) {
115+
t.Fatalf("Expected %d environment variables, got %d", len(tc.expectedEnv), len(actualEnv))
116+
}
117+
118+
for i, expected := range tc.expectedEnv {
119+
if actualEnv[i].Name != expected.Name || actualEnv[i].Value != expected.Value {
120+
t.Errorf("At index %d: expected %s=%s, got %s=%s", i, expected.Name, expected.Value, actualEnv[i].Name, actualEnv[i].Value)
121+
}
122+
}
123+
})
124+
}
125+
}

pkg/runtime/framework/plugins/mpi/mpi.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,19 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
190190
)
191191
switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation {
192192
case trainer.MPIImplementationOpenMPI:
193+
// Collect all custom environment variable names from the TrainJob to export via SSH.
194+
var envNames []string
195+
if trainJob.Spec.Trainer != nil {
196+
for _, env := range trainJob.Spec.Trainer.Env {
197+
// Only include variables with a static Value.
198+
// Variables with ValueFrom (e.g. FieldRef, ResourceFieldRef) are pod-specific
199+
// and should not be propagated as cluster-wide constants.
200+
if env.ValueFrom == nil {
201+
envNames = append(envNames, env.Name)
202+
}
203+
}
204+
}
205+
193206
apply.UpsertEnvVars(
194207
&info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env,
195208
*corev1ac.EnvVar().
@@ -205,6 +218,23 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
205218
WithName(constants.OpenMPIEnvKeyRSHArgs).
206219
WithValue(constants.OpenMPIEnvDefaultValueRSHArgs),
207220
)
221+
222+
// Automatically tell OpenMPI to export the custom variables to all nodes.
223+
if len(envNames) > 0 {
224+
envList := ""
225+
for i, name := range envNames {
226+
if i > 0 {
227+
envList += ";"
228+
}
229+
envList += name
230+
}
231+
apply.UpsertEnvVars(
232+
&info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env,
233+
*corev1ac.EnvVar().
234+
WithName(constants.OpenMPIEnvBaseEnvList).
235+
WithValue(envList),
236+
)
237+
}
208238
default:
209239
return fmt.Errorf("MPI implementation for %v doesn't supported", info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation)
210240
}

pkg/runtime/framework/plugins/mpi/mpi_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,178 @@ trainJob-node-1-0.trainJob slots=1
829829
},
830830
wantBuildError: errorGetSSHAuthSecretFromAPI,
831831
},
832+
"environment variables propagation and filtering for OpenMPI": {
833+
trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "trainJob").
834+
UID("trainJob").
835+
Trainer(utiltesting.MakeTrainJobTrainerWrapper().
836+
Env(
837+
corev1.EnvVar{Name: "STATIC_VAR", Value: "static_value"},
838+
corev1.EnvVar{Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"}}},
839+
).
840+
Obj()).
841+
Obj(),
842+
info: &runtime.Info{
843+
RuntimePolicy: runtime.RuntimePolicy{
844+
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
845+
MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil).
846+
Obj(),
847+
},
848+
TemplateSpec: runtime.TemplateSpec{
849+
PodSets: []runtime.PodSet{
850+
{
851+
Name: constants.Launcher,
852+
Count: ptr.To[int32](1),
853+
Endpoints: func(yield func(string) bool) {
854+
yield("trainJob-launcher-0-0.trainJob")
855+
},
856+
Containers: []runtime.Container{{
857+
Name: constants.Node,
858+
}},
859+
},
860+
{
861+
Name: constants.Node,
862+
Count: ptr.To[int32](1),
863+
Endpoints: func(yield func(string) bool) {
864+
yield("trainJob-node-0-0.trainJob")
865+
},
866+
Containers: []runtime.Container{{
867+
Name: constants.Node,
868+
}},
869+
},
870+
},
871+
},
872+
Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)},
873+
},
874+
wantInfo: &runtime.Info{
875+
Labels: nil,
876+
Annotations: nil,
877+
RuntimePolicy: runtime.RuntimePolicy{
878+
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
879+
MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil).
880+
Obj(),
881+
},
882+
TemplateSpec: runtime.TemplateSpec{
883+
PodSets: []runtime.PodSet{
884+
{
885+
Name: constants.Launcher,
886+
Count: ptr.To[int32](1),
887+
Containers: []runtime.Container{{
888+
Name: constants.Node,
889+
Env: []corev1ac.EnvVarApplyConfiguration{
890+
*corev1ac.EnvVar().
891+
WithName(constants.OpenMPIEnvHostFileLocation).
892+
WithValue(fmt.Sprintf("%s/%s", constants.MPIHostfileDir, constants.MPIHostfileName)),
893+
*corev1ac.EnvVar().
894+
WithName(constants.OpenMPIEnvKeepFQDNHostNames).
895+
WithValue("true"),
896+
*corev1ac.EnvVar().
897+
WithName(constants.OpenMPIEnvDefaultSlots).
898+
WithValue("1"),
899+
*corev1ac.EnvVar().
900+
WithName(constants.OpenMPIEnvKeyRSHArgs).
901+
WithValue(constants.OpenMPIEnvDefaultValueRSHArgs),
902+
*corev1ac.EnvVar().
903+
WithName(constants.OpenMPIEnvBaseEnvList).
904+
WithValue("STATIC_VAR"),
905+
},
906+
VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{
907+
*corev1ac.VolumeMount().
908+
WithName(constants.MPISSHAuthVolumeName).
909+
WithMountPath("/root/.ssh"),
910+
*corev1ac.VolumeMount().
911+
WithName(constants.MPIHostfileVolumeName).
912+
WithMountPath(constants.MPIHostfileDir),
913+
},
914+
}},
915+
Volumes: []corev1ac.VolumeApplyConfiguration{
916+
*corev1ac.Volume().
917+
WithName(constants.MPISSHAuthVolumeName).
918+
WithSecret(corev1ac.SecretVolumeSource().
919+
WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)).
920+
WithItems(
921+
corev1ac.KeyToPath().
922+
WithKey(corev1.SSHAuthPrivateKey).
923+
WithPath(constants.MPISSHPrivateKeyFile),
924+
corev1ac.KeyToPath().
925+
WithKey(constants.MPISSHPublicKey).
926+
WithPath(constants.MPISSHPublicKeyFile),
927+
corev1ac.KeyToPath().
928+
WithKey(constants.MPISSHPublicKey).
929+
WithPath(constants.MPISSHAuthorizedKeys),
930+
),
931+
),
932+
*corev1ac.Volume().
933+
WithName(constants.MPIHostfileVolumeName).
934+
WithConfigMap(corev1ac.ConfigMapVolumeSource().
935+
WithName(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix)).
936+
WithItems(
937+
corev1ac.KeyToPath().
938+
WithKey(constants.MPIHostfileName).
939+
WithPath(constants.MPIHostfileName).
940+
WithMode(0444),
941+
),
942+
),
943+
},
944+
Endpoints: func(yield func(string) bool) {
945+
yield("trainJob-launcher-0-0.trainJob")
946+
},
947+
},
948+
{
949+
Name: constants.Node,
950+
Count: ptr.To[int32](1),
951+
Containers: []runtime.Container{{
952+
Name: constants.Node,
953+
VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{
954+
*corev1ac.VolumeMount().
955+
WithName(constants.MPISSHAuthVolumeName).
956+
WithMountPath("/root/.ssh"),
957+
},
958+
}},
959+
Volumes: []corev1ac.VolumeApplyConfiguration{
960+
*corev1ac.Volume().
961+
WithName(constants.MPISSHAuthVolumeName).
962+
WithSecret(corev1ac.SecretVolumeSource().
963+
WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)).
964+
WithItems(
965+
corev1ac.KeyToPath().
966+
WithKey(corev1.SSHAuthPrivateKey).
967+
WithPath(constants.MPISSHPrivateKeyFile),
968+
corev1ac.KeyToPath().
969+
WithKey(constants.MPISSHPublicKey).
970+
WithPath(constants.MPISSHPublicKeyFile),
971+
corev1ac.KeyToPath().
972+
WithKey(constants.MPISSHPublicKey).
973+
WithPath(constants.MPISSHAuthorizedKeys),
974+
),
975+
),
976+
},
977+
Endpoints: func(yield func(string) bool) {
978+
yield("trainJob-node-0-0.trainJob")
979+
},
980+
},
981+
},
982+
},
983+
Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)},
984+
},
985+
wantObjs: []apiruntime.Object{
986+
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
987+
WithImmutable(true).
988+
WithType(corev1.SecretTypeSSHAuth).
989+
WithData(map[string][]byte{
990+
constants.MPISSHPublicKey: []byte("EXIST"),
991+
corev1.SSHAuthPrivateKey: []byte("EXIST"),
992+
}).
993+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
994+
Obj(),
995+
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
996+
WithData(map[string]string{
997+
constants.MPIHostfileName: `trainJob-node-0-0.trainJob slots=1
998+
`,
999+
}).
1000+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
1001+
Obj(),
1002+
},
1003+
},
8321004
}
8331005
for name, tc := range cases {
8341006
t.Run(name, func(t *testing.T) {

0 commit comments

Comments
 (0)