diff --git a/pkg/controller/mpi_job_controller.go b/pkg/controller/mpi_job_controller.go
index 47a1b589..22683d2e 100644
--- a/pkg/controller/mpi_job_controller.go
+++ b/pkg/controller/mpi_job_controller.go
@@ -687,9 +687,28 @@ func (c *MPIJobController) syncHandler(key string) error {
 	}
 
 	if launcher != nil {
-		if isMPIJobSuspended(mpiJob) != isJobSuspended(launcher) {
-			// align the suspension state of launcher with the MPIJob
-			launcher.Spec.Suspend = ptr.To(isMPIJobSuspended(mpiJob))
+		if !isMPIJobSuspended(mpiJob) && isJobSuspended(launcher) {
+			// We are unsuspending, hence we need to sync the pod template with the current MPIJob spec.
+			// This is important for interop with Kueue as it may have injected schedulingGates.
+			// Kubernetes validates that a Job template is immutable once StartTime is set,
+			// so we must clear it first via a status sub-resource update (consistent with JobSet).
+			if launcher.Status.StartTime != nil {
+				launcher.Status.StartTime = nil
+				if _, err := c.kubeClient.BatchV1().Jobs(namespace).UpdateStatus(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
+					return err
+				}
+			}
+
+			// Sync mutable scheduling directives (KEP-2926) and unsuspend.
+			desiredPodTemplate := c.newLauncherPodTemplate(mpiJob)
+			syncLauncherSchedulingDirectives(launcher, &desiredPodTemplate)
+			launcher.Spec.Suspend = ptr.To(false)
+			if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
+				return err
+			}
+		} else if isMPIJobSuspended(mpiJob) && !isJobSuspended(launcher) {
+			// align the suspension state of launcher with the MPIJob.
+			launcher.Spec.Suspend = ptr.To(true)
 			if _, err := c.kubeClient.BatchV1().Jobs(namespace).Update(context.TODO(), launcher, metav1.UpdateOptions{}); err != nil {
 				return err
 			}
@@ -1623,6 +1642,30 @@ func (c *MPIJobController) newLauncherPodTemplate(mpiJob *kubeflow.MPIJob) corev
 	}
 }
 
+// syncLauncherSchedulingDirectives updates the mutable scheduling directives (as per KEP-2926) on
+// the launcher Job's pod template to match the desired template.
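+// Labels and annotations are merged key-by-key (existing keys are overwritten, stale keys are
+// left in place), while nodeSelector, tolerations, and schedulingGates are replaced wholesale.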
+func syncLauncherSchedulingDirectives(launcher *batchv1.Job, desired *corev1.PodTemplateSpec) {
+	if launcher.Spec.Template.Labels == nil {
+		launcher.Spec.Template.Labels = make(map[string]string)
+	}
+	for k, v := range desired.Labels {
+		launcher.Spec.Template.Labels[k] = v
+	}
+
+	if desired.Annotations != nil {
+		if launcher.Spec.Template.Annotations == nil {
+			launcher.Spec.Template.Annotations = make(map[string]string)
+		}
+		for k, v := range desired.Annotations {
+			launcher.Spec.Template.Annotations[k] = v
+		}
+	}
+
+	launcher.Spec.Template.Spec.NodeSelector = desired.Spec.NodeSelector
+	launcher.Spec.Template.Spec.Tolerations = desired.Spec.Tolerations
+	launcher.Spec.Template.Spec.SchedulingGates = desired.Spec.SchedulingGates
+}
+
 func (c *MPIJobController) jobPods(j *batchv1.Job) ([]*corev1.Pod, error) {
 	selector, err := metav1.LabelSelectorAsSelector(j.Spec.Selector)
 	if err != nil {
diff --git a/pkg/controller/mpi_job_controller_test.go b/pkg/controller/mpi_job_controller_test.go
index ea39f21c..888a7f4a 100644
--- a/pkg/controller/mpi_job_controller_test.go
+++ b/pkg/controller/mpi_job_controller_test.go
@@ -1024,14 +1024,16 @@ func TestResumeMPIJob(t *testing.T) {
 	// resume the MPIJob
 	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
 
-	// expect creation of the pods
+	// expect creation of the worker pods
 	for i := 0; i < int(replicas); i++ {
 		worker := fmjc.newWorker(mpiJob, i)
 		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
 	}
 
-	// expect the launcher update to resume it
+	// expect the launcher update to sync scheduling directives and resume it
 	launcherCopy := launcher.DeepCopy()
+	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
 	launcherCopy.Spec.Suspend = ptr.To(false)
 	f.expectUpdateJobAction(launcherCopy)
 
@@ -1044,6 +1046,183 @@ func TestResumeMPIJob(t *testing.T) {
 	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
 }
 
+func TestResumeMPIJobWithExistingLauncher(t *testing.T) {
+	// Tests the running→suspended→resumed path where a launcher already exists
+	// (from before suspension) with startTime == nil. The launcher should be
+	// updated in place with synced scheduling directives (KEP-2926).
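+	// The startTime != nil re-admission path is covered separately by TestResumeMPIJobClearsStartTime.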
+	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
+	f := newFixture(t, "")
+
+	var replicas int32 = 8
+	startTime := metav1.Now()
+	mpiJob := newMPIJob("test", &replicas, &startTime, nil)
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
+	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
+	updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
+	msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
+	mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
+		kubeflow.MPIReplicaTypeLauncher: {},
+		kubeflow.MPIReplicaTypeWorker:   {},
+	}
+	f.setUpMPIJob(mpiJob)
+
+	scheme.Scheme.Default(mpiJob)
+	f.expectCreateServiceAction(newJobService(mpiJob))
+	cfgMap := newConfigMap(mpiJob, replicas, "")
+	updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
+	f.setUpConfigMap(cfgMap)
+	secret, err := newSSHAuthSecret(mpiJob)
+	if err != nil {
+		t.Fatalf("Failed creating secret")
+	}
+	f.setUpSecret(secret)
+
+	// set up an existing suspended launcher (startTime == nil, never started)
+	fmjc := f.newFakeMPIJobController()
+	launcher := fmjc.newLauncherJob(mpiJob)
+	launcher.Spec.Suspend = ptr.To(true)
+	// Simulate Kueue injecting scheduling directives into the MPIJob template
+	// after the launcher was already created (so the launcher has stale templates).
+	launcherSpec := &mpiJob.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher].Template
+	launcherSpec.Spec.NodeSelector = map[string]string{
+		"foo": "bar",
+	}
+	launcherSpec.Spec.Tolerations = []corev1.Toleration{
+		{Key: "gpu", Operator: corev1.TolerationOpEqual, Value: "true", Effect: corev1.TaintEffectNoSchedule},
+	}
+	launcherSpec.Spec.SchedulingGates = []corev1.PodSchedulingGate{
+		{Name: "kueue.x-k8s.io/topology"},
+	}
+	if launcherSpec.Annotations == nil {
+		launcherSpec.Annotations = make(map[string]string)
+	}
+	launcherSpec.Annotations["kueue.x-k8s.io/workload"] = "my-workload"
+	f.setUpLauncher(launcher)
+
+	fakeClock.Sleep(time.Second)
+
+	// resume the MPIJob
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+
+	// expect creation of the worker pods
+	for i := 0; i < int(replicas); i++ {
+		worker := fmjc.newWorker(mpiJob, i)
+		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
+	}
+
+	// expect the launcher to be updated (scheduling directives synced + unsuspended)
+	launcherCopy := launcher.DeepCopy()
+	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
+	launcherCopy.Spec.Suspend = ptr.To(false)
+
+	// Verify the synced launcher has the Kueue-injected scheduling directives.
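+	// launcherCopy is the Job the fixture expects the controller to write via expectUpdateJobAction;
+	// these checks guard against syncLauncherSchedulingDirectives silently dropping fields.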
+	tmpl := &launcherCopy.Spec.Template
+	if tmpl.Spec.NodeSelector["foo"] != "bar" {
+		t.Errorf("expected nodeSelector to be synced, got %v", tmpl.Spec.NodeSelector)
+	}
+	if len(tmpl.Spec.Tolerations) != 1 || tmpl.Spec.Tolerations[0].Key != "gpu" {
+		t.Errorf("expected tolerations to be synced, got %v", tmpl.Spec.Tolerations)
+	}
+	if len(tmpl.Spec.SchedulingGates) != 1 || tmpl.Spec.SchedulingGates[0].Name != "kueue.x-k8s.io/topology" {
+		t.Errorf("expected schedulingGates to be synced, got %v", tmpl.Spec.SchedulingGates)
+	}
+	if tmpl.Annotations["kueue.x-k8s.io/workload"] != "my-workload" {
+		t.Errorf("expected annotations to be synced, got %v", tmpl.Annotations)
+	}
+
+	f.expectUpdateJobAction(launcherCopy)
+
+	// expect status update
+	mpiJobCopy := mpiJob.DeepCopy()
+	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
+	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
+	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
+}
+
+func TestResumeMPIJobClearsStartTime(t *testing.T) {
+	// Tests the re-admission case where the launcher has startTime != nil.
+	// The controller should clear StartTime via a status sub-resource update
+	// (consistent with JobSet), then sync scheduling directives and unsuspend.
+	fakeClock := clocktesting.NewFakeClock(time.Now().Truncate(time.Second))
+	f := newFixture(t, "")
+
+	var replicas int32 = 8
+	startTime := metav1.Now()
+	mpiJob := newMPIJob("test", &replicas, &startTime, nil)
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(true)
+	msg := fmt.Sprintf("MPIJob %s/%s is created.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobCreated, corev1.ConditionTrue, mpiJobCreatedReason, msg)
+	updateMPIJobConditions(mpiJob, kubeflow.JobSuspended, corev1.ConditionTrue, mpiJobSuspendedReason, "MPIJob suspended")
+	msg = fmt.Sprintf("MPIJob %s/%s is suspended.", mpiJob.Namespace, mpiJob.Name)
+	updateMPIJobConditions(mpiJob, kubeflow.JobRunning, corev1.ConditionFalse, mpiJobSuspendedReason, msg)
+	mpiJob.Status.ReplicaStatuses = map[kubeflow.MPIReplicaType]*kubeflow.ReplicaStatus{
+		kubeflow.MPIReplicaTypeLauncher: {},
+		kubeflow.MPIReplicaTypeWorker:   {},
+	}
+	f.setUpMPIJob(mpiJob)
+
+	scheme.Scheme.Default(mpiJob)
+	f.expectCreateServiceAction(newJobService(mpiJob))
+	cfgMap := newConfigMap(mpiJob, replicas, "")
+	updateDiscoverHostsInConfigMap(cfgMap, mpiJob, nil, "")
+	f.setUpConfigMap(cfgMap)
+	secret, err := newSSHAuthSecret(mpiJob)
+	if err != nil {
+		t.Fatalf("Failed creating secret")
+	}
+	f.setUpSecret(secret)
+
+	// set up an existing suspended launcher that was previously started (startTime != nil)
+	fmjc := f.newFakeMPIJobController()
+	launcher := fmjc.newLauncherJob(mpiJob)
+	launcher.Spec.Suspend = ptr.To(true)
+	launcherStartTime := metav1.Now()
+	launcher.Status.StartTime = &launcherStartTime
+	f.setUpLauncher(launcher)
+
+	fakeClock.Sleep(time.Second)
+
+	// resume the MPIJob
+	mpiJob.Spec.RunPolicy.Suspend = ptr.To(false)
+
+	// expect creation of worker pods
+	for i := 0; i < int(replicas); i++ {
+		worker := fmjc.newWorker(mpiJob, i)
+		f.kubeActions = append(f.kubeActions, core.NewCreateAction(schema.GroupVersionResource{Resource: "pods"}, mpiJob.Namespace, worker))
+	}
+
+	// expect a status sub-resource update to clear launcher's StartTime
+	launcherStatusCleared := launcher.DeepCopy()
+	launcherStatusCleared.Status.StartTime = nil
+	f.kubeActions = append(f.kubeActions,
+		core.NewUpdateSubresourceAction(
+			schema.GroupVersionResource{Resource: "jobs", Group: "batch", Version: "v1"},
+			"status",
+			mpiJob.Namespace,
+			launcherStatusCleared,
+		))
+
+	// expect the launcher to be updated (scheduling directives synced + unsuspended)
+	launcherCopy := launcher.DeepCopy()
+	launcherCopy.Status.StartTime = nil
+	desiredPodTemplate := fmjc.newLauncherPodTemplate(mpiJob)
+	syncLauncherSchedulingDirectives(launcherCopy, &desiredPodTemplate)
+	launcherCopy.Spec.Suspend = ptr.To(false)
+	f.expectUpdateJobAction(launcherCopy)
+
+	// expect MPIJob status update
+	mpiJobCopy := mpiJob.DeepCopy()
+	mpiJobCopy.Status.StartTime = &metav1.Time{Time: fakeClock.Now()}
+	updateMPIJobConditions(mpiJobCopy, kubeflow.JobSuspended, corev1.ConditionFalse, "MPIJobResumed", "MPIJob resumed")
+	f.expectUpdateMPIJobStatusAction(mpiJobCopy)
+
+	f.runWithClock(t.Context(), getKey(mpiJob, t), fakeClock)
+}
+
 func TestWorkerNotControlledByUs(t *testing.T) {
 	f := newFixture(t, "")
 	startTime := metav1.Now()