Commit 880261d

Fix launcher job scheduling directives when unsuspending
1 parent 4954c99 commit 880261d

File tree: 3 files changed, +244 −34 lines


pkg/controller/mpi_job_controller.go
Lines changed: 55 additions & 12 deletions

@@ -672,24 +672,43 @@ func (c *MPIJobController) syncHandler(key string) error {
 			if err != nil {
 				return err
 			}
-		}
-		if launcher == nil {
-			if mpiJob.Spec.LauncherCreationPolicy == kubeflow.LauncherCreationPolicyAtStartup || c.countReadyWorkerPods(worker) == len(worker) {
-				launcher, err = c.kubeClient.BatchV1().Jobs(namespace).Create(context.TODO(), c.newLauncherJob(mpiJob), metav1.CreateOptions{})
-				if err != nil {
-					c.recorder.Eventf(mpiJob, corev1.EventTypeWarning, mpiJobFailedReason, "launcher pod created failed: %v", err)
-					return fmt.Errorf("creating launcher Pod: %w", err)
+			if launcher == nil {
+				if mpiJob.Spec.LauncherCreationPolicy == kubeflow.LauncherCreationPolicyAtStartup || c.countReadyWorkerPods(worker) == len(worker) {
+					launcher, err = c.kubeClient.BatchV1().Jobs(namespace).Create(context.TODO(), c.newLauncherJob(mpiJob), metav1.CreateOptions{})
+					if err != nil {
+						c.recorder.Eventf(mpiJob, corev1.EventTypeWarning, mpiJobFailedReason, "launcher pod created failed: %v", err)
+						return fmt.Errorf("creating launcher Pod: %w", err)
+					}
+				} else {
+					klog.V(4).Infof("Waiting for workers %s/%s to start.", mpiJob.Namespace, mpiJob.Name)
 				}
-			} else {
-				klog.V(4).Infof("Waiting for workers %s/%s to start.", mpiJob.Namespace, mpiJob.Name)
 			}
 		}
 	}

 	if launcher != nil {
-		if isMPIJobSuspended(mpiJob) != isJobSuspended(launcher) {
-			// align the suspension state of launcher with the MPIJob
-			launcher.Spec.Suspend = ptr.To(isMPIJobSuspended(mpiJob))
+		if !isMPIJobSuspended(mpiJob) && isJobSuspended(launcher) {
+			// We are unsuspending, hence we need to sync the pod template with the current MPIJob spec.
+			// This is important for interop with Kueue as it may have injected schedulingGates.
+			// Kubernetes validates that a Job template is immutable once StartTime is set,
+			// so we must clear it first via a status sub-resource update (consistent with JobSet).
+			if launcher.Status.StartTime != nil {
+				launcher.Status.StartTime = nil
+				if _, err := c.kubeClient.BatchV1().Jobs(namespace).UpdateStatus(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
+					return err
+				}
+			}
+
+			// Sync mutable scheduling directives (KEP-2926) and unsuspend.
+			desiredPodTemplate := c.newLauncherPodTemplate(mpiJob)
+			syncLauncherSchedulingDirectives(launcher, &desiredPodTemplate)
+			launcher.Spec.Suspend = ptr.To(false)
+			if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
+				return err
+			}
+		} else if isMPIJobSuspended(mpiJob) && !isJobSuspended(launcher) {
+			// align the suspension state of launcher with the MPIJob.
+			launcher.Spec.Suspend = ptr.To(true)
 			if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
 				return err
 			}
@@ -1623,6 +1642,30 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev
 	}
 }

+// syncLauncherSchedulingDirectives updates the mutable scheduling directives (as per KEP-2926) on
+// the launcher Job's pod template to match the desired template.
+func syncLauncherSchedulingDirectives(launcher *batchv1.Job, desired *corev1.PodTemplateSpec) {
+	if launcher.Spec.Template.Labels == nil {
+		launcher.Spec.Template.Labels = make(map[string]string)
+	}
+	for k, v := range desired.Labels {
+		launcher.Spec.Template.Labels[k] = v
+	}
+
+	if desired.Annotations != nil {
+		if launcher.Spec.Template.Annotations == nil {
+			launcher.Spec.Template.Annotations = make(map[string]string)
+		}
+		for k, v := range desired.Annotations {
+			launcher.Spec.Template.Annotations[k] = v
+		}
+	}
+
+	launcher.Spec.Template.Spec.NodeSelector = desired.Spec.NodeSelector
+	launcher.Spec.Template.Spec.Tolerations = desired.Spec.Tolerations
+	launcher.Spec.Template.Spec.SchedulingGates = desired.Spec.SchedulingGates
+}
+
 func (c *MPIJobController) jobPods(j *batchv1.Job) ([]*corev1.Pod, error) {
 	selector, err := metav1.LabelSelectorAsSelector(j.Spec.Selector)
 	if err != nil {
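The unsuspend branch above performs three API calls in sequence: clear the launcher's status.startTime through the Job status sub-resource, copy the mutable scheduling directives (KEP-2926) from the freshly rendered pod template, then set spec.suspend to false and update the Job. The following sketch shows that sequence against a plain batch/v1 Job using client-go. It is illustrative only: the helper name unsuspendJob and its signature are not from this commit, and it omits the label and annotation syncing that syncLauncherSchedulingDirectives also performs.

```go
package example

import (
	"context"
	"fmt"

	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
	"k8s.io/utils/ptr"
)

// unsuspendJob is an illustrative helper (not part of the mpi-operator) showing the
// order of operations used above: clear startTime via the status sub-resource,
// refresh the mutable scheduling directives, then unsuspend.
func unsuspendJob(ctx context.Context, client kubernetes.Interface, job *batchv1.Job, desired *corev1.PodTemplateSpec) error {
	jobs := client.BatchV1().Jobs(job.Namespace)

	// The API server rejects pod-template mutations while status.startTime is set,
	// so a previously started Job must have it cleared first.
	if job.Status.StartTime != nil {
		job.Status.StartTime = nil
		if _, err := jobs.UpdateStatus(ctx, job, metav1.UpdateOptions{}); err != nil {
			return fmt.Errorf("clearing startTime: %w", err)
		}
	}

	// Node selector, tolerations and scheduling gates are mutable on a suspended Job (KEP-2926).
	job.Spec.Template.Spec.NodeSelector = desired.Spec.NodeSelector
	job.Spec.Template.Spec.Tolerations = desired.Spec.Tolerations
	job.Spec.Template.Spec.SchedulingGates = desired.Spec.SchedulingGates

	job.Spec.Suspend = ptr.To(false)
	_, err := jobs.Update(ctx, job, metav1.UpdateOptions{})
	return err
}
```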

pkg/controller/mpi_job_controller_test.go
Lines changed: 179 additions & 15 deletions

@@ -869,17 +869,10 @@ func TestCreateSuspendedMPIJob(t *testing.T) {
 	}
 	f.expectCreateSecretAction(secret)

-	// expect creating of the launcher
-	fmjc := f.newFakeMPIJobController()
-	launcher := fmjc.newLauncherJob(mpiJob)
-	launcher.Spec.Suspend = ptr.To(true)
-	f.expectCreateJobAction(launcher)
-
 	// expect an update to add the conditions
 	mpiJobCopy := mpiJob.DeepCopy()
 	mpiJobCopy.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
-		kubeflow.MPIReplicaTypeLauncher: {},
-		kubeflow.MPIReplicaTypeWorker:   {},
+		kubeflow.MPIReplicaTypeWorker: {},
 	}
 	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
 	updateMPIJobConditions(mpiJobCopy, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
@@ -984,7 +977,67 @@ func TestResumeMPIJob(t *testing.T) {
 	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
 	f := newFixture(t, "")

-	// create a suspended job
+	// create a suspended job (no launcher exists, it's not created while the MPIJob is suspended)
+	var replicas int32 = 8
+	mpiJob := newMPIJob("test", &replicas, nil, nil)
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
+	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
+	updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
+	msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
+	mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
+		kubeflow.MPIReplicaTypeLauncher: {},
+		kubeflow.MPIReplicaTypeWorker:   {},
+	}
+	f.setUpMPIJob(mpiJob)
+
+	// set up existing objects
+	scheme.Scheme.Default(mpiJob)
+	f.expectCreateServiceAction(newJobService(mpiJob))
+	cfgMap := newConfigMap(mpiJob, replicas, "")
+	updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
+	f.setUpConfigMap(cfgMap)
+	secret, err := newSSHAuthSecret(mpiJob)
+	if err != nil {
+		t.Fatalf("Failed creating secret")
+	}
+	f.setUpSecret(secret)
+
+	fmjc := f.newFakeMPIJobController()
+
+	// move the timer by a second so that the StartTime is updated after resume
+	fakeClock.Sleep(time.Second)
+
+	// resume the MPIJob
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+
+	// expect creation of the worker pods
+	for i := 0; i < int(replicas); i++ {
+		worker := fmjc.newWorker(mpiJob, i)
+		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
+	}
+
+	// expect the launcher to be created since it didn't exist while suspended
+	launcher := fmjc.newLauncherJob(mpiJob)
+	f.expectCreateJobAction(launcher)
+
+	// expect an update to add the conditions
+	mpiJobCopy := mpiJob.DeepCopy()
+	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
+	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
+	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
+}
+
+func TestResumeMPIJobWithExistingLauncher(t *testing.T) {
+	// Tests the running→suspended→resumed path where a launcher already exists
+	// (from before suspension) with startTime == nil. The launcher should be
+	// updated in place with synced scheduling directives (KEP-2926).
+	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
+	f := newFixture(t, "")
+
 	var replicas int32 = 8
 	startTime := metav1.Now()
 	mpiJob := newMPIJob("test", &replicas, &startTime, nil)
@@ -1000,7 +1053,6 @@ func TestResumeMPIJob(t *testing.T) {
 	}
 	f.setUpMPIJob(mpiJob)

-	// expect creation of objects
 	scheme.Scheme.Default(mpiJob)
 	f.expectCreateServiceAction(newJobService(mpiJob))
 	cfgMap := newConfigMap(mpiJob, replicas, "")
@@ -1012,30 +1064,142 @@
 	}
 	f.setUpSecret(secret)

-	// expect creating of the launcher
+	// set up an existing suspended launcher (startTime == nil, never started)
 	fmjc := f.newFakeMPIJobController()
 	launcher := fmjc.newLauncherJob(mpiJob)
 	launcher.Spec.Suspend = ptr.To(true)
+	// Simulate Kueue injecting scheduling directives into the MPIJob template
+	// after the launcher was already created (so the launcher has stale templates).
+	launcherSpec := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
+	launcherSpec.Spec.NodeSelector = map[string]string{
+		"foo": "bar",
+	}
+	launcherSpec.Spec.Tolerations = []corev1.Toleration{
+		{Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule},
+	}
+	launcherSpec.Spec.SchedulingGates = []corev1.PodSchedulingGate{
+		{Name: "kueue.x-k8s.io/topology"},
+	}
+	if launcherSpec.Annotations == nil {
+		launcherSpec.Annotations = make(map[string]string)
+	}
+	launcherSpec.Annotations["kueue.x-k8s.io/workload"] = "my-workload"
+	f.setUpLauncher(launcher)
+
+	fakeClock.Sleep(time.Second)
+
+	// resume the MPIJob
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+
+	// expect creation of the worker pods
+	for i := 0; i < int(replicas); i++ {
+		worker := fmjc.newWorker(mpiJob, i)
+		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
+	}
+
+	// expect the launcher to be updated (scheduling directives synced + unsuspended)
+	launcherCopy := launcher.DeepCopy()
+	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
+	launcherCopy.Spec.Suspend = ptr.To(false)
+
+	// Verify the synced launcher has the Kueue-injected scheduling directives.
+	tmpl := &launcherCopy.Spec.Template
+	if tmpl.Spec.NodeSelector["foo"] != "bar" {
+		t.Errorf("expected nodeSelector to be synced, got %v", tmpl.Spec.NodeSelector)
+	}
+	if len(tmpl.Spec.Tolerations) != 1 || tmpl.Spec.Tolerations[0].Key != "gpu" {
+		t.Errorf("expected tolerations to be synced, got %v", tmpl.Spec.Tolerations)
+	}
+	if len(tmpl.Spec.SchedulingGates) != 1 || tmpl.Spec.SchedulingGates[0].Name != "kueue.x-k8s.io/topology" {
+		t.Errorf("expected schedulingGates to be synced, got %v", tmpl.Spec.SchedulingGates)
+	}
+	if tmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" {
+		t.Errorf("expected annotations to be synced, got %v", tmpl.Annotations)
+	}
+
+	f.expectUpdateJobAction(launcherCopy)
+
+	// expect status update
+	mpiJobCopy := mpiJob.DeepCopy()
+	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
+	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
+	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
+}
+
+func TestResumeMPIJobClearsStartTime(t *testing.T) {
+	// Tests the re-admission case where the launcher has startTime != nil.
+	// The controller should clear StartTime via a status sub-resource update
+	// (consistent with JobSet), then sync scheduling directives and unsuspend.
+	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
+	f := newFixture(t, "")
+
+	var replicas int32 = 8
+	startTime := metav1.Now()
+	mpiJob := newMPIJob("test", &replicas, &startTime, nil)
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
+	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
+	updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
+	msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
+	mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
+		kubeflow.MPIReplicaTypeLauncher: {},
+		kubeflow.MPIReplicaTypeWorker:   {},
+	}
+	f.setUpMPIJob(mpiJob)
+
+	scheme.Scheme.Default(mpiJob)
+	f.expectCreateServiceAction(newJobService(mpiJob))
+	cfgMap := newConfigMap(mpiJob, replicas, "")
+	updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
+	f.setUpConfigMap(cfgMap)
+	secret, err := newSSHAuthSecret(mpiJob)
+	if err != nil {
+		t.Fatalf("Failed creating secret")
+	}
+	f.setUpSecret(secret)
+
+	// set up an existing suspended launcher that was previously started (startTime != nil)
+	fmjc := f.newFakeMPIJobController()
+	launcher := fmjc.newLauncherJob(mpiJob)
+	launcher.Spec.Suspend = ptr.To(true)
+	launcherStartTime := metav1.Now()
+	launcher.Status.StartTime = &launcherStartTime
 	f.setUpLauncher(launcher)

-	// move the timer by a second so that the StartTime is updated after resume
 	fakeClock.Sleep(time.Second)

 	// resume the MPIJob
 	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)

-	// expect creation of the pods
+	// expect creation of worker pods
 	for i := 0; i < int(replicas); i++ {
 		worker := fmjc.newWorker(mpiJob, i)
 		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
 	}

-	// expect the launcher update to resume it
+	// expect a status sub-resource update to clear launcher's StartTime
+	launcherStatusCleared := launcher.DeepCopy()
+	launcherStatusCleared.Status.StartTime = nil
+	f.kubeActions = append(f.kubeActions, core.NewUpdateSubresourceAction(
+		schema.GroupVersionResource{Resource: "jobs", Group: "batch", Version: "v1"},
+		"status",
+		mpiJob.Namespace,
+		launcherStatusCleared,
+	))
+
+	// expect the launcher to be updated (scheduling directives synced + unsuspended)
 	launcherCopy := launcher.DeepCopy()
+	launcherCopy.Status.StartTime = nil
+	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
 	launcherCopy.Spec.Suspend = ptr.To(false)
 	f.expectUpdateJobAction(launcherCopy)

-	// expect an update to add the conditions
+	// expect MPIJob status update
 	mpiJobCopy := mpiJob.DeepCopy()
 	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
 	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
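The fixture-based tests above exercise the new behavior through the full syncHandler. For completeness, a narrower, hypothetical unit test of just the new helper could look like the sketch below; it is not part of this commit and assumes the same package and imports (batchv1, corev1, metav1) as the surrounding test file.

```go
func TestSyncLauncherSchedulingDirectivesCopiesFields(t *testing.T) {
	// Hypothetical standalone test; not part of this commit. It exercises the
	// syncLauncherSchedulingDirectives helper added in pkg/controller/mpi_job_controller.go directly.
	launcher := &batchv1.Job{}
	desired := corev1.PodTemplateSpec{
		ObjectMeta: metav1.ObjectMeta{
			Labels:      map[string]string{"app": "mpi-launcher"},
			Annotations: map[string]string{"kueue.x-k8s.io/workload": "wl"},
		},
		Spec: corev1.PodSpec{
			NodeSelector:    map[string]string{"foo": "bar"},
			SchedulingGates: []corev1.PodSchedulingGate{{Name: "kueue.x-k8s.io/topology"}},
		},
	}

	syncLauncherSchedulingDirectives(launcher, &desired)

	tmpl := launcher.Spec.Template
	if tmpl.Labels["app"] != "mpi-launcher" {
		t.Errorf("labels not synced: %v", tmpl.Labels)
	}
	if tmpl.Annotations["kueue.x-k8s.io/workload"] != "wl" {
		t.Errorf("annotations not synced: %v", tmpl.Annotations)
	}
	if tmpl.Spec.NodeSelector["foo"] != "bar" {
		t.Errorf("nodeSelector not synced: %v", tmpl.Spec.NodeSelector)
	}
	if len(tmpl.Spec.SchedulingGates) != 1 {
		t.Errorf("schedulingGates not synced: %v", tmpl.Spec.SchedulingGates)
	}
}
```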

test/integration/mpi_job_controller_test.go
Lines changed: 10 additions & 7 deletions

@@ -357,7 +357,7 @@ func TestMPIJobResumingAndSuspending(t *testing.T) {
 			},
 		},
 	}
-	// 1. Create suspended MPIJob
+	// 1. Create suspended MPIJob (no launcher Job created while suspended)
 	var err error
 	mpiJob, err = s.mpiClient.KubeflowV2beta1().MPIJobs(s.namespace).Create(ctx, mpiJob, metav1.CreateOptions{})
 	if err != nil {
@@ -373,20 +373,23 @@
 		Reason: "MPIJobSuspended",
 	}, mpiJob))

-	_, launcherJob := validateMPIJobDependencies(ctx, t, s.kClient, mpiJob, 0, nil)
 	mpiJob = validateMPIJobStatus(ctx, t, s.mpiClient, mpiJob, map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
-		kubeflow.MPIReplicaTypeLauncher: {},
-		kubeflow.MPIReplicaTypeWorker:   {},
+		kubeflow.MPIReplicaTypeWorker: {},
 	})
+	// Verify no launcher Job exists while suspended.
+	launcherJob, err := getLauncherJobForMPIJob(ctx, s.kClient, mpiJob)
+	if err != nil {
+		t.Fatalf("Getting launcher Job: %v", err)
+	}
+	if launcherJob != nil {
+		t.Errorf("Launcher Job should not exist while MPIJob is suspended")
+	}
 	if !mpiJobHasCondition(mpiJob, kubeflow.JobCreated) {
 		t.Errorf("MPIJob missing Created condition")
 	}
 	if !mpiJobHasCondition(mpiJob, kubeflow.JobSuspended) {
 		t.Errorf("MPIJob missing Suspended condition")
 	}
-	if !isJobSuspended(launcherJob) {
-		t.Errorf("LauncherJob is suspended")
-	}
 	if mpiJob.Status.StartTime != nil {
 		t.Errorf("MPIJob has unexpected start time: %v", mpiJob.Status.StartTime)
 	}
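The new assertion relies on getLauncherJobForMPIJob, whose definition is not shown in this diff. A helper along the lines of the hypothetical sketch below would satisfy it, returning nil with no error when the launcher Job is absent; the "<mpijob-name>-launcher" naming and the body are assumptions, not code from this commit.

```go
// Hypothetical helper; not part of this diff. Assumes the launcher Job is named
// "<mpijob-name>-launcher" and that apierrors is k8s.io/apimachinery/pkg/api/errors.
func getLauncherJobForMPIJob(ctx context.Context, kClient kubernetes.Interface, mpiJob *kubeflow.MPIJob) (*batchv1.Job, error) {
	job, err := kClient.BatchV1().Jobs(mpiJob.Namespace).Get(ctx, mpiJob.Name+"-launcher", metav1.GetOptions{})
	if apierrors.IsNotFound(err) {
		// A missing launcher is not an error for the caller: returning nil lets the
		// test assert that no launcher exists while the MPIJob is suspended.
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	return job, nil
}
```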
