@@ -829,6 +829,178 @@ trainJob-node-1-0.trainJob slots=1
829829 },
830830 wantBuildError : errorGetSSHAuthSecretFromAPI ,
831831 },
832+ "environment variables propagation and filtering for OpenMPI" : {
833+ trainJob : utiltesting .MakeTrainJobWrapper (metav1 .NamespaceDefault , "trainJob" ).
834+ UID ("trainJob" ).
835+ Trainer (utiltesting .MakeTrainJobTrainerWrapper ().
836+ Env (
837+ corev1.EnvVar {Name : "STATIC_VAR" , Value : "static_value" },
838+ corev1.EnvVar {Name : "POD_NAME" , ValueFrom : & corev1.EnvVarSource {FieldRef : & corev1.ObjectFieldSelector {FieldPath : "metadata.name" }}},
839+ ).
840+ Obj ()).
841+ Obj (),
842+ info : & runtime.Info {
843+ RuntimePolicy : runtime.RuntimePolicy {
844+ MLPolicySource : utiltesting .MakeMLPolicySourceWrapper ().
845+ MPIPolicy (ptr.To [int32 ](1 ), trainer .MPIImplementationOpenMPI , ptr .To ("/root/.ssh" ), nil ).
846+ Obj (),
847+ },
848+ TemplateSpec : runtime.TemplateSpec {
849+ PodSets : []runtime.PodSet {
850+ {
851+ Name : constants .Launcher ,
852+ Count : ptr.To [int32 ](1 ),
853+ Endpoints : func (yield func (string ) bool ) {
854+ yield ("trainJob-launcher-0-0.trainJob" )
855+ },
856+ Containers : []runtime.Container {{
857+ Name : constants .Node ,
858+ }},
859+ },
860+ {
861+ Name : constants .Node ,
862+ Count : ptr.To [int32 ](1 ),
863+ Endpoints : func (yield func (string ) bool ) {
864+ yield ("trainJob-node-0-0.trainJob" )
865+ },
866+ Containers : []runtime.Container {{
867+ Name : constants .Node ,
868+ }},
869+ },
870+ },
871+ },
872+ Scheduler : & runtime.Scheduler {PodLabels : make (map [string ]string )},
873+ },
874+ wantInfo : & runtime.Info {
875+ Labels : nil ,
876+ Annotations : nil ,
877+ RuntimePolicy : runtime.RuntimePolicy {
878+ MLPolicySource : utiltesting .MakeMLPolicySourceWrapper ().
879+ MPIPolicy (ptr.To [int32 ](1 ), trainer .MPIImplementationOpenMPI , ptr .To ("/root/.ssh" ), nil ).
880+ Obj (),
881+ },
882+ TemplateSpec : runtime.TemplateSpec {
883+ PodSets : []runtime.PodSet {
884+ {
885+ Name : constants .Launcher ,
886+ Count : ptr.To [int32 ](1 ),
887+ Containers : []runtime.Container {{
888+ Name : constants .Node ,
889+ Env : []corev1ac.EnvVarApplyConfiguration {
890+ * corev1ac .EnvVar ().
891+ WithName (constants .OpenMPIEnvHostFileLocation ).
892+ WithValue (fmt .Sprintf ("%s/%s" , constants .MPIHostfileDir , constants .MPIHostfileName )),
893+ * corev1ac .EnvVar ().
894+ WithName (constants .OpenMPIEnvKeepFQDNHostNames ).
895+ WithValue ("true" ),
896+ * corev1ac .EnvVar ().
897+ WithName (constants .OpenMPIEnvDefaultSlots ).
898+ WithValue ("1" ),
899+ * corev1ac .EnvVar ().
900+ WithName (constants .OpenMPIEnvKeyRSHArgs ).
901+ WithValue (constants .OpenMPIEnvDefaultValueRSHArgs ),
902+ * corev1ac .EnvVar ().
903+ WithName (constants .OpenMPIEnvBaseEnvList ).
904+ WithValue ("STATIC_VAR" ),
905+ },
906+ VolumeMounts : []corev1ac.VolumeMountApplyConfiguration {
907+ * corev1ac .VolumeMount ().
908+ WithName (constants .MPISSHAuthVolumeName ).
909+ WithMountPath ("/root/.ssh" ),
910+ * corev1ac .VolumeMount ().
911+ WithName (constants .MPIHostfileVolumeName ).
912+ WithMountPath (constants .MPIHostfileDir ),
913+ },
914+ }},
915+ Volumes : []corev1ac.VolumeApplyConfiguration {
916+ * corev1ac .Volume ().
917+ WithName (constants .MPISSHAuthVolumeName ).
918+ WithSecret (corev1ac .SecretVolumeSource ().
919+ WithSecretName (fmt .Sprintf ("trainJob%s" , constants .MPISSHAuthSecretSuffix )).
920+ WithItems (
921+ corev1ac .KeyToPath ().
922+ WithKey (corev1 .SSHAuthPrivateKey ).
923+ WithPath (constants .MPISSHPrivateKeyFile ),
924+ corev1ac .KeyToPath ().
925+ WithKey (constants .MPISSHPublicKey ).
926+ WithPath (constants .MPISSHPublicKeyFile ),
927+ corev1ac .KeyToPath ().
928+ WithKey (constants .MPISSHPublicKey ).
929+ WithPath (constants .MPISSHAuthorizedKeys ),
930+ ),
931+ ),
932+ * corev1ac .Volume ().
933+ WithName (constants .MPIHostfileVolumeName ).
934+ WithConfigMap (corev1ac .ConfigMapVolumeSource ().
935+ WithName (fmt .Sprintf ("trainJob%s" , constants .MPIHostfileConfigMapSuffix )).
936+ WithItems (
937+ corev1ac .KeyToPath ().
938+ WithKey (constants .MPIHostfileName ).
939+ WithPath (constants .MPIHostfileName ).
940+ WithMode (0444 ),
941+ ),
942+ ),
943+ },
944+ Endpoints : func (yield func (string ) bool ) {
945+ yield ("trainJob-launcher-0-0.trainJob" )
946+ },
947+ },
948+ {
949+ Name : constants .Node ,
950+ Count : ptr.To [int32 ](1 ),
951+ Containers : []runtime.Container {{
952+ Name : constants .Node ,
953+ VolumeMounts : []corev1ac.VolumeMountApplyConfiguration {
954+ * corev1ac .VolumeMount ().
955+ WithName (constants .MPISSHAuthVolumeName ).
956+ WithMountPath ("/root/.ssh" ),
957+ },
958+ }},
959+ Volumes : []corev1ac.VolumeApplyConfiguration {
960+ * corev1ac .Volume ().
961+ WithName (constants .MPISSHAuthVolumeName ).
962+ WithSecret (corev1ac .SecretVolumeSource ().
963+ WithSecretName (fmt .Sprintf ("trainJob%s" , constants .MPISSHAuthSecretSuffix )).
964+ WithItems (
965+ corev1ac .KeyToPath ().
966+ WithKey (corev1 .SSHAuthPrivateKey ).
967+ WithPath (constants .MPISSHPrivateKeyFile ),
968+ corev1ac .KeyToPath ().
969+ WithKey (constants .MPISSHPublicKey ).
970+ WithPath (constants .MPISSHPublicKeyFile ),
971+ corev1ac .KeyToPath ().
972+ WithKey (constants .MPISSHPublicKey ).
973+ WithPath (constants .MPISSHAuthorizedKeys ),
974+ ),
975+ ),
976+ },
977+ Endpoints : func (yield func (string ) bool ) {
978+ yield ("trainJob-node-0-0.trainJob" )
979+ },
980+ },
981+ },
982+ },
983+ Scheduler : & runtime.Scheduler {PodLabels : make (map [string ]string )},
984+ },
985+ wantObjs : []apiruntime.Object {
986+ utiltesting .MakeSecretWrapper (fmt .Sprintf ("trainJob%s" , constants .MPISSHAuthSecretSuffix ), metav1 .NamespaceDefault ).
987+ WithImmutable (true ).
988+ WithType (corev1 .SecretTypeSSHAuth ).
989+ WithData (map [string ][]byte {
990+ constants .MPISSHPublicKey : []byte ("EXIST" ),
991+ corev1 .SSHAuthPrivateKey : []byte ("EXIST" ),
992+ }).
993+ ControllerReference (trainer .SchemeGroupVersion .WithKind (trainer .TrainJobKind ), "trainJob" , "trainJob" ).
994+ Obj (),
995+ utiltesting .MakeConfigMapWrapper (fmt .Sprintf ("trainJob%s" , constants .MPIHostfileConfigMapSuffix ), metav1 .NamespaceDefault ).
996+ WithData (map [string ]string {
997+ constants .MPIHostfileName : `trainJob-node-0-0.trainJob slots=1
998+ ` ,
999+ }).
1000+ ControllerReference (trainer .SchemeGroupVersion .WithKind (trainer .TrainJobKind ), "trainJob" , "trainJob" ).
1001+ Obj (),
1002+ },
1003+ },
8321004 }
8331005 for name , tc := range cases {
8341006 t .Run (name , func (t * testing.T ) {
0 commit comments