Skip to content

Commit 4db032a

Browse files
[feature]:add validations for MPIRuntime with RunLauncherAsNode (#2551)
* [feature]:add validatioons for MPIRuntime with RunLauncherAsNode Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * fix:changed "launcher" and "node" with constants Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * fix:removed the constants with string values Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * fix: add nil checks for before derefrencing pointers Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * fix:add check for trainjob being nil Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * updae with suggestions Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * update with suggestions Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * Update pkg/runtime/framework/plugins/mpi/mpi.go Co-authored-by: Yuki Iwai <yuki.iwai.tz@gmail.com> Signed-off-by: Harshal Malani <142563202+Harshal292004@users.noreply.github.com> * Update pkg/runtime/framework/plugins/mpi/mpi.go Co-authored-by: Yuki Iwai <yuki.iwai.tz@gmail.com> Signed-off-by: Harshal Malani <142563202+Harshal292004@users.noreply.github.com> * fix merge conflicts Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * test(runtime): add UT for mpi runtime validate function Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * Update mpi and mpi_test according to suggestions Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> * Update mpi_test Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> --------- Signed-off-by: Harshal292004 <malaniharshal95@gmail.com> Signed-off-by: Harshal Malani <142563202+Harshal292004@users.noreply.github.com> Co-authored-by: Yuki Iwai <yuki.iwai.tz@gmail.com>
1 parent 3ef89f4 commit 4db032a

2 files changed

Lines changed: 74 additions & 2 deletions

File tree

pkg/runtime/framework/plugins/mpi/mpi.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,24 @@ func (m *MPI) Name() string {
8282
// TODO (andreyvelich): Add validation to check that TrainJob doesn't have MPI envs.
8383
// TODO (andreyvelich): We should validate that envs from different plugins don't conflict with each other.
8484
// Ref: https://github.com/kubeflow/trainer/pull/2308#discussion_r1823229940
85-
8685
func (m *MPI) Validate(runtimeInfo *runtime.Info, _, newJobObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
8786
var allErrs field.ErrorList
8887
if runtimeInfo == nil || runtimeInfo.RuntimePolicy.MLPolicySource == nil || runtimeInfo.RuntimePolicy.MLPolicySource.MPI == nil {
8988
return nil, allErrs
9089
}
91-
90+
specPath := field.NewPath("spec")
9291
if trainJobTrainer := newJobObj.Spec.Trainer; trainJobTrainer != nil && trainJobTrainer.NumProcPerNode != nil {
9392
if trainJobTrainer.NumProcPerNode.Type != intstr.Int {
9493
allErrs = append(allErrs, field.Invalid(numProcPerNodePath, *trainJobTrainer.NumProcPerNode, "must have an int value for MPI TrainJob"))
9594
}
9695
}
96+
// validate PodSet configurations based on NumNodes and RunLauncherAsNode.
97+
if trainJobTrainer := newJobObj.Spec.Trainer; trainJobTrainer != nil && ptr.Deref(trainJobTrainer.NumNodes, 1) >= 2 && ptr.Deref(runtimeInfo.RuntimePolicy.MLPolicySource.MPI.RunLauncherAsNode, false) {
98+
if runtimeInfo.FindPodSetByName(constants.Launcher) == nil || runtimeInfo.FindPodSetByName(constants.Node) == nil {
99+
numNodesPath := specPath.Child("trainer", "numNodes")
100+
allErrs = append(allErrs, field.Invalid(numNodesPath, newJobObj.Spec.Trainer.NumNodes, "must have 1 when MPI trainingRuntime with enabled runLauncherAsNode does not have either launcher and node"))
101+
}
102+
}
97103
return nil, allErrs
98104
}
99105

pkg/runtime/framework/plugins/mpi/mpi_test.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,72 @@ func TestValidate(t *testing.T) {
807807
).
808808
Obj(),
809809
},
810+
"runtime does not have Launcher but TrainJob has only 1 numNodes": {
811+
info: runtime.NewInfo(
812+
runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper().
813+
WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper().
814+
MPIPolicy(ptr.To[int32](1), ptr.To(trainer.MPIImplementationOpenMPI), nil, ptr.To(true)).
815+
Obj(),
816+
).
817+
Obj(),
818+
),
819+
runtime.WithPodSet(constants.Node, ptr.To(constants.AncestorTrainer), 1, corev1.PodSpec{}, corev1ac.PodSpec().
820+
WithContainers(corev1ac.Container().WithName(constants.Node)),
821+
),
822+
),
823+
newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Trainer(&trainer.Trainer{NumNodes: ptr.To(int32(1))}).Obj(),
824+
},
825+
"runtime does not have Launcher even though MPI with runLauncherAsNode has 2 numNodes": {
826+
info: runtime.NewInfo(
827+
runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper().
828+
WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper().
829+
MPIPolicy(ptr.To[int32](1), ptr.To(trainer.MPIImplementationOpenMPI), nil, ptr.To(true)).
830+
Obj(),
831+
).
832+
Obj(),
833+
),
834+
runtime.WithPodSet(constants.Node, nil, 1, corev1.PodSpec{}, corev1ac.PodSpec().
835+
WithContainers(corev1ac.Container().WithName(constants.Node)),
836+
),
837+
),
838+
newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Trainer(&trainer.Trainer{NumNodes: ptr.To(int32(2))}).Obj(),
839+
wantError: field.ErrorList{
840+
field.Invalid(field.NewPath("spec").Child("trainer", "numNodes"), ptr.To(int32(2)), "must have 1 when MPI trainingRuntime with enabled runLauncherAsNode does not have either launcher and node"),
841+
},
842+
},
843+
"runtime does not have Node even though MPI with runLauncherAsNode has 2 numNodes": {
844+
info: runtime.NewInfo(
845+
runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper().
846+
WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper().
847+
MPIPolicy(ptr.To[int32](1), ptr.To(trainer.MPIImplementationOpenMPI), nil, ptr.To(true)).
848+
Obj(),
849+
).
850+
Obj(),
851+
),
852+
runtime.WithPodSet(constants.Launcher, nil, 1, corev1.PodSpec{}, corev1ac.PodSpec().
853+
WithContainers(corev1ac.Container().WithName(constants.Node)),
854+
),
855+
),
856+
newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Trainer(&trainer.Trainer{NumNodes: ptr.To(int32(2))}).Obj(),
857+
wantError: field.ErrorList{
858+
field.Invalid(field.NewPath("spec").Child("trainer", "numNodes"), ptr.To(int32(2)), "must have 1 when MPI trainingRuntime with enabled runLauncherAsNode does not have either launcher and node"),
859+
},
860+
},
861+
"runtime does not have Launcher and Node even though MPI with runLauncherAsNode has 2 numNodes": {
862+
info: runtime.NewInfo(
863+
runtime.WithMLPolicySource(utiltesting.MakeMLPolicyWrapper().
864+
WithMLPolicySource(*utiltesting.MakeMLPolicySourceWrapper().
865+
MPIPolicy(ptr.To[int32](1), ptr.To(trainer.MPIImplementationOpenMPI), nil, ptr.To(true)).
866+
Obj(),
867+
).
868+
Obj(),
869+
),
870+
),
871+
newObj: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Trainer(&trainer.Trainer{NumNodes: ptr.To(int32(2))}).Obj(),
872+
wantError: field.ErrorList{
873+
field.Invalid(field.NewPath("spec").Child("trainer", "numNodes"), ptr.To(int32(2)), "must have 1 when MPI trainingRuntime with enabled runLauncherAsNode does not have either launcher and node"),
874+
},
875+
},
810876
}
811877
for name, tc := range cases {
812878
t.Run(name, func(t *testing.T) {

0 commit comments

Comments
 (0)