Skip to content

Commit d1848ce

Browse files
committed
feat(api): Add terminationGracePeriodSeconds to PodSpecPatch in TrainJob
Adds terminationGracePeriodSeconds field to PodSpecPatch so users can configure pod termination grace period per TrainJob via RuntimePatches. This is needed for distributed training with PyTorch Elastic (torchrun) where large models (70B+ parameters) require more than the default 30s to complete JIT checkpointing before SIGKILL on node drain or TrainJob pause. No changes to merge logic in trainingruntime.go are required since the existing StrategicMergePatch applied at batchv1.JobTemplateSpec level already handles this field automatically. Closes #3285 Signed-off-by: krishdef7 <gargkrish06@gmail.com>
1 parent 941f4a2 commit d1848ce

File tree

11 files changed

+198
-2
lines changed

11 files changed

+198
-2
lines changed

api/openapi-spec/swagger.json

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

api/python_api/kubeflow_trainer_api/models/trainer_v1alpha1_pod_spec_patch.py

Lines changed: 4 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

charts/kubeflow-trainer/crds/trainer.kubeflow.org_trainjobs.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3122,6 +3122,14 @@ spec:
31223122
x-kubernetes-validations:
31233123
- message: field is immutable
31243124
rule: self == oldSelf
3125+
terminationGracePeriodSeconds:
3126+
description: |-
3127+
terminationGracePeriodSeconds patches the termination grace period for Pods
3128+
in the target job templates. This allows users to configure sufficient time
3129+
for checkpoint saving on TrainJob completion or node drain.
3130+
format: int64
3131+
minimum: 0
3132+
type: integer
31253133
tolerations:
31263134
description: tolerations patches
31273135
the Pod's tolerations.

manifests/base/crds/trainer.kubeflow.org_trainjobs.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3122,6 +3122,14 @@ spec:
31223122
x-kubernetes-validations:
31233123
- message: field is immutable
31243124
rule: self == oldSelf
3125+
terminationGracePeriodSeconds:
3126+
description: |-
3127+
terminationGracePeriodSeconds patches the termination grace period for Pods
3128+
in the target job templates. This allows users to configure sufficient time
3129+
for checkpoint saving on TrainJob completion or node drain.
3130+
format: int64
3131+
minimum: 0
3132+
type: integer
31253133
tolerations:
31263134
description: tolerations patches
31273135
the Pod's tolerations.

pkg/apis/trainer/v1alpha1/trainjob_types.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,13 @@ type PodSpecPatch struct {
432432
// +kubebuilder:validation:MaxItems=16
433433
// +optional
434434
SchedulingGates []corev1.PodSchedulingGate `json:"schedulingGates,omitempty"`
435+
436+
// terminationGracePeriodSeconds patches the termination grace period for Pods
437+
// in the target job templates. This allows users to configure sufficient time
438+
// for checkpoint saving on TrainJob completion or node drain.
439+
// +kubebuilder:validation:Minimum=0
440+
// +optional
441+
TerminationGracePeriodSeconds *int64 `json:"terminationGracePeriodSeconds,omitempty"`
435442
}
436443

437444
// ContainerPatch represents parameters that can be patched using PodSpecPatch.

pkg/apis/trainer/v1alpha1/zz_generated.deepcopy.go

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/apis/trainer/v1alpha1/zz_generated.openapi.go

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/client/applyconfiguration/trainer/v1alpha1/podspecpatch.go

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/util/testing/wrapper.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,15 @@ func (j *JobSetWrapper) SchedulingGates(rJobName string, schedulingGates ...core
338338
return j
339339
}
340340

341+
func (j *JobSetWrapper) TerminationGracePeriodSeconds(rJobName string, seconds int64) *JobSetWrapper {
342+
for i, rJob := range j.Spec.ReplicatedJobs {
343+
if rJob.Name == rJobName {
344+
j.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec.TerminationGracePeriodSeconds = &seconds
345+
}
346+
}
347+
return j
348+
}
349+
341350
func (j *JobSetWrapper) ImagePullSecrets(rJobName string, imagePullSecrets ...corev1.LocalObjectReference) *JobSetWrapper {
342351
for i, rJob := range j.Spec.ReplicatedJobs {
343352
if rJob.Name == rJobName {

test/integration/controller/trainjob_controller_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,79 @@ var _ = ginkgo.Describe("TrainJob controller", ginkgo.Ordered, func() {
301301
g.Expect(k8sClient.Update(ctx, trainJob)).Should(testingutil.BeInvalidError())
302302
}, util.Timeout, util.Interval).Should(gomega.Succeed())
303303
})
304+
ginkgo.It("Should propagate terminationGracePeriodSeconds from RuntimePatches to JobSet pods", func() {
305+
ginkgo.By("Creating a TrainingRuntime and TrainJob with terminationGracePeriodSeconds patch")
306+
gracePeriodRuntime := testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha-grace").
307+
RuntimeSpec(
308+
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(ns.Name, "alpha-grace").Spec).
309+
WithMLPolicy(
310+
testingutil.MakeMLPolicyWrapper().
311+
WithNumNodes(1).
312+
Obj(),
313+
).
314+
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
315+
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
316+
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
317+
Obj()).
318+
Obj()
319+
gomega.Expect(k8sClient.Create(ctx, gracePeriodRuntime)).Should(gomega.Succeed())
320+
gomega.Eventually(func(g gomega.Gomega) {
321+
g.Expect(k8sClient.Get(ctx, client.ObjectKeyFromObject(gracePeriodRuntime), gracePeriodRuntime)).Should(gomega.Succeed())
322+
}, util.Timeout, util.Interval).Should(gomega.Succeed())
323+
324+
gracePeriod := int64(300)
325+
graceJob := testingutil.MakeTrainJobWrapper(ns.Name, "grace-period-job").
326+
Suspend(true).
327+
RuntimeRef(trainer.GroupVersion.WithKind(trainer.TrainingRuntimeKind), "alpha-grace").
328+
RuntimePatches([]trainer.RuntimePatch{{
329+
Manager: "test.io/manager",
330+
TrainingRuntimeSpec: &trainer.TrainingRuntimeSpecPatch{
331+
Template: &trainer.JobSetTemplatePatch{
332+
Spec: &trainer.JobSetSpecPatch{
333+
ReplicatedJobs: []trainer.ReplicatedJobPatch{{
334+
Name: constants.Node,
335+
Template: &trainer.JobTemplatePatch{
336+
Spec: &trainer.JobSpecPatch{
337+
Template: &trainer.PodTemplatePatch{
338+
Spec: &trainer.PodSpecPatch{
339+
TerminationGracePeriodSeconds: &gracePeriod,
340+
},
341+
},
342+
},
343+
},
344+
}},
345+
},
346+
},
347+
},
348+
}}).
349+
Trainer(
350+
testingutil.MakeTrainJobTrainerWrapper().
351+
Container("test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
352+
Obj()).
353+
Obj()
354+
graceJobKey := client.ObjectKeyFromObject(graceJob)
355+
gomega.Expect(k8sClient.Create(ctx, graceJob)).Should(gomega.Succeed())
356+
357+
ginkgo.By("Checking that JobSet node pods have terminationGracePeriodSeconds set to 300")
358+
gomega.Eventually(func(g gomega.Gomega) {
359+
jobSet := &jobsetv1alpha2.JobSet{}
360+
g.Expect(k8sClient.Get(ctx, graceJobKey, jobSet)).Should(gomega.Succeed())
361+
g.Expect(jobSet).Should(gomega.BeComparableTo(
362+
testingutil.MakeJobSetWrapper(ns.Name, graceJobKey.Name).
363+
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), graceJobKey.Name, string(graceJob.UID)).
364+
Suspend(true).
365+
Replicas(1, constants.Node, constants.DatasetInitializer, constants.ModelInitializer).
366+
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer).
367+
Completions(1, constants.DatasetInitializer, constants.ModelInitializer).
368+
NumNodes(1).
369+
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
370+
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
371+
Container(constants.Node, constants.Node, "test:trainjob", []string{"trainjob"}, []string{"trainjob"}, resRequests).
372+
TerminationGracePeriodSeconds(constants.Node, gracePeriod).
373+
Obj(),
374+
util.IgnoreObjectMetadata))
375+
}, util.Timeout, util.Interval).Should(gomega.Succeed())
376+
})
304377
})
305378

306379
ginkgo.Context("Integration tests for the Torch Runtime", func() {

0 commit comments

Comments
 (0)